aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Neto <dneto@google.com>2017-06-14 10:40:27 -0400
committerDavid Neto <dneto@google.com>2017-06-14 10:41:00 -0400
commit4d55e9f66f91a62d5493c866cfeec77891b4cf55 (patch)
tree1e37258c2741503b1f2e148d571101e59b3a7fdd
parent86c6fe001dfdbd8e8b39f40aeeab4c691ac42f04 (diff)
parent7c8da66bc27cc5c4ccb6a0fa612f56c9417518ff (diff)
downloadspirv-tools-4d55e9f66f91a62d5493c866cfeec77891b4cf55.tar.gz
Merge remote-tracking branch 'aosp/upstream-master' into update-shaderc
7c8da66 mem2reg: Add pass to eliminate local loads and stores in single block. 1567cdd Don't install googletest and googlemock aa7e687 Mem2Reg: Add Local Access Chain Convert pass d71d976 Fix memory leak in ValidateBinaryUsingContextAndValidationState 66fc105 Bots print output from timed out tests e7aff80 Fixed misspelled ctest flag --output_on_failure ddf4de6 Support building on FreeBSD 3bea99d CFA: Move TraversalRoots and ComputeAugmentedCFG into CFA d6f2979 CFA: Pull in CalculateDominators df6537c DefUseManager: Fix ReplaceAllUsesWith() to update inst_to_used_ids_ 20fe946 Added extension SPV_VALIDATOR_ignore_type_decl_unique 3492cc6 Remove unused this in lambda capture dbc2049 Add SPIR-V 1.2 support, for OpenCL 2.2 eb720b2 Fix size_t conversion error on MinGW 51b6778 Update CHANGES: note fix of issue 629 bba812f Inline: Inline early return function if no returns in loop. 3eb716c Added bit stream utils f5facf8 Stats analyzer aggregates OpConstant usage b4cf371 Stats analyzer uses validator 01b2875 Avoid snprintf warning in GCC 7.1 b25b330 Inline: Create CFA class 3f90058 Update set_spec_const_default_value_test.cpp 87a3f65 Added Markov chain analysis to stats bad90d9 Inline: Change "--inline-entry-points-all" to "-exhaustive" d870dbe Inline: Fix inliner description in usage message to reflect exceptions. a107d34 Inline: Do not inline functions with multiple returns (for now) 144f59e Add bit pattern interface for setting default value for spec constants 1d8efb0 Update CHANGES with recent news 1e309af Added --compact-ids to /tools/opt b173d1c Added option --preserve-numeric-ids to tools/spirv-as 4f21640 Added statistical analysis tool (tool/stats) 72debb8 Test source language HLSL bf68c81 Support SPV_KHR_storage_buffer_storage_class 23af06c Validator support for Variable Pointer extension. 4895ace Update cap tests for SPV_KHR_16bit_storage 4087e89 Test asm,dis support for SPV_KHR_variable_pointers 11a867f Add FlattenDecoration transform 5c3c054 Group targets into folders dec3f5e Update spirv-opt to use spvtools::Optimizer afc60bb Fix optimizer on when to write the binary ad3b082 Add /EHs for targets for MSVC 4be6abe Fix spelling in SPV_AMD_gcn_shader support 58e7a3e Fix typo in method name Struct::AddMemberName ceb1d4f Avoid inlining calls to external functions 4fc9302 opt::Function::cbegin and cend are const Test: checkbuild.py on Linux; unit tests on Windows Change-Id: I1d8c460282840789c369769e2784db1f4684590c
-rw-r--r--.appveyor.yml2
-rw-r--r--.travis.yml2
-rw-r--r--CHANGES11
-rw-r--r--CMakeLists.txt9
-rw-r--r--examples/CMakeLists.txt20
-rw-r--r--examples/cpp-interface/CMakeLists.txt10
-rw-r--r--external/CMakeLists.txt13
-rw-r--r--include/spirv-tools/libspirv.h19
-rw-r--r--include/spirv-tools/libspirv.hpp12
-rw-r--r--include/spirv-tools/optimizer.hpp61
-rw-r--r--source/CMakeLists.txt16
-rw-r--r--source/assembly_grammar.h2
-rw-r--r--source/binary.cpp2
-rw-r--r--source/cfa.h336
-rw-r--r--source/enum_set.h2
-rw-r--r--source/ext_inst.cpp9
-rw-r--r--source/extinst.spv-amd-gcn-shader.grammar.json (renamed from source/extinst.amd-gcn-shader.grammar.json)6
-rw-r--r--source/instruction.h2
-rw-r--r--source/libspirv.cpp12
-rw-r--r--source/name_mapper.cpp2
-rw-r--r--source/opcode.cpp36
-rw-r--r--source/opcode.h6
-rw-r--r--source/operand.cpp11
-rw-r--r--source/opt/CMakeLists.txt10
-rw-r--r--source/opt/basic_block.h10
-rw-r--r--source/opt/build_module.cpp5
-rw-r--r--source/opt/build_module.h12
-rw-r--r--source/opt/compact_ids_pass.cpp57
-rw-r--r--source/opt/compact_ids_pass.h34
-rw-r--r--source/opt/def_use_manager.cpp5
-rw-r--r--source/opt/flatten_decoration_pass.cpp164
-rw-r--r--source/opt/flatten_decoration_pass.h34
-rw-r--r--source/opt/function.h10
-rw-r--r--source/opt/inline_pass.cpp248
-rw-r--r--source/opt/inline_pass.h72
-rw-r--r--source/opt/instruction.h6
-rw-r--r--source/opt/local_access_chain_convert_pass.cpp369
-rw-r--r--source/opt/local_access_chain_convert_pass.h167
-rw-r--r--source/opt/local_single_block_elim_pass.cpp344
-rw-r--r--source/opt/local_single_block_elim_pass.h153
-rw-r--r--source/opt/log.h7
-rw-r--r--source/opt/module.cpp53
-rw-r--r--source/opt/module.h9
-rw-r--r--source/opt/optimizer.cpp31
-rw-r--r--source/opt/passes.h4
-rw-r--r--source/opt/reflect.h2
-rw-r--r--source/opt/set_spec_constant_default_value_pass.cpp79
-rw-r--r--source/opt/set_spec_constant_default_value_pass.h27
-rw-r--r--source/opt/type_manager.cpp2
-rw-r--r--source/opt/types.cpp4
-rw-r--r--source/opt/types.h6
-rw-r--r--source/print.cpp2
-rw-r--r--source/spirv_constant.h2
-rw-r--r--source/spirv_definition.h2
-rw-r--r--source/spirv_stats.cpp217
-rw-r--r--source/spirv_stats.h88
-rw-r--r--source/spirv_target_env.cpp9
-rw-r--r--source/table.cpp1
-rw-r--r--source/table.h4
-rw-r--r--source/text.cpp65
-rw-r--r--source/text_handler.cpp39
-rw-r--r--source/text_handler.h13
-rw-r--r--source/util/bit_stream.cpp387
-rw-r--r--source/util/bit_stream.h378
-rw-r--r--source/util/string_utils.cpp2
-rw-r--r--source/val/basic_block.h2
-rw-r--r--source/val/function.cpp94
-rw-r--r--source/val/function.h2
-rw-r--r--source/val/instruction.h16
-rw-r--r--source/val/validation_state.cpp8
-rw-r--r--source/val/validation_state.h8
-rw-r--r--source/validate.cpp18
-rw-r--r--source/validate.h55
-rw-r--r--source/validate_cfg.cpp135
-rw-r--r--source/validate_id.cpp29
-rw-r--r--source/validate_instruction.cpp3
-rw-r--r--source/validate_type_unique.cpp6
-rw-r--r--test/CMakeLists.txt12
-rw-r--r--test/binary_header_get_test.cpp2
-rw-r--r--test/binary_to_text_test.cpp37
-rw-r--r--test/bit_stream.cpp1138
-rw-r--r--test/enum_string_mapping_test.cpp10
-rw-r--r--test/operand_capabilities_test.cpp3
-rw-r--r--test/opt/CMakeLists.txt25
-rw-r--r--test/opt/compact_ids_test.cpp90
-rw-r--r--test/opt/flatten_decoration_test.cpp234
-rw-r--r--test/opt/inline_test.cpp244
-rw-r--r--test/opt/local_access_chain_convert_test.cpp422
-rw-r--r--test/opt/local_single_block_elim.cpp469
-rw-r--r--test/opt/optimizer_test.cpp109
-rw-r--r--test/opt/pass_fixture.h31
-rw-r--r--test/opt/set_spec_const_default_value_test.cpp537
-rw-r--r--test/preserve_numeric_ids_test.cpp158
-rw-r--r--test/scripts/test_compact_ids.py102
-rw-r--r--test/stats/CMakeLists.txt31
-rw-r--r--test/stats/stats_aggregate_test.cpp436
-rw-r--r--test/stats/stats_analyzer_test.cpp172
-rw-r--r--test/target_env_test.cpp6
-rw-r--r--test/text_to_binary.debug_test.cpp2
-rw-r--r--test/text_to_binary.extension_test.cpp58
-rw-r--r--test/text_to_binary.type_declaration_test.cpp1
-rw-r--r--test/unit_spirv.h2
-rw-r--r--test/val/val_id_test.cpp219
-rw-r--r--test/val/val_type_unique_test.cpp18
-rw-r--r--tools/CMakeLists.txt11
-rw-r--r--tools/as/as.cpp22
-rw-r--r--tools/cfg/cfg.cpp6
-rw-r--r--tools/dis/dis.cpp6
-rw-r--r--tools/opt/opt.cpp83
-rw-r--r--tools/stats/stats.cpp158
-rw-r--r--tools/stats/stats_analyzer.cpp241
-rw-r--r--tools/stats/stats_analyzer.h52
-rw-r--r--tools/val/val.cpp12
-rwxr-xr-xutils/generate_grammar_tables.py33
114 files changed, 8791 insertions, 511 deletions
diff --git a/.appveyor.yml b/.appveyor.yml
index b8a5eaaf..4edce4e6 100644
--- a/.appveyor.yml
+++ b/.appveyor.yml
@@ -39,4 +39,4 @@ build_script:
- cmake --build . --config %CONFIGURATION%
test_script:
- - ctest -C %CONFIGURATION% --output-on-failure
+ - ctest -C %CONFIGURATION% --output-on-failure --timeout 300
diff --git a/.travis.yml b/.travis.yml
index 9a14b874..848ee47a 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -76,7 +76,7 @@ script:
else
export NPROC=`sysctl -n hw.ncpu`;
fi
- - if [[ "$BUILD_NDK" != "ON" ]]; then ctest -j${NPROC} --output_on_failure; fi
+ - if [[ "$BUILD_NDK" != "ON" ]]; then ctest -j${NPROC} --output-on-failure --timeout 300; fi
notifications:
diff --git a/CHANGES b/CHANGES
index ffe7b000..493b9758 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,7 +1,12 @@
Revision history for SPIRV-Tools
v2016.7-dev 2017-01-06
- - Optimizer: Add inlining of all function calls in entry points
+ - Add SPIR-V 1.2
+ - OpenCL 2.2 support is now based on SPIR-V 1.2
+ - Optimizer: Add inlining of all function calls in entry points.
+ - Optimizer: Add flattening of decoration groups. Fixes #602
+ - Optimizer: Add Id compaction (minimize Id bound). Fixes #624
+ - Assembler: Add option to preserve numeric ids. Fixes #625
- Add build target spirv-tools-vimsyntax to generate spvasm.vim, a SPIR-V
assembly syntax file for Vim.
- Version string: Allow overriding of wall clock timestamp with contents
@@ -16,6 +21,10 @@ v2016.7-dev 2017-01-06
header.
#548: Validator: Error when the reserved OpImageSparseSampleProj* opcodes
are used.
+ #611: spvtools::Optimizer was failing to save the module to the output
+ binary vector when all passes succeded without changes.
+ #629: The inline-entry-points-all optimization could generate invalidly
+ structured code when the inlined function had early returns.
v2016.6 2016-12-13
- Published the C++ interface for assembling, disassembling, validation, and
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 214ff917..7da83753 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -19,6 +19,7 @@ if (POLICY CMP0054)
# https://cmake.org/cmake/help/v3.1/policy/CMP0054.html
cmake_policy(SET CMP0054 NEW)
endif()
+set_property(GLOBAL PROPERTY USE_FOLDERS ON)
project(spirv-tools)
enable_testing()
@@ -36,6 +37,8 @@ elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
add_definitions(-DSPIRV_MAC)
elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
add_definitions(-DSPIRV_ANDROID)
+elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "FreeBSD")
+ add_definitions(-DSPIRV_FREEBSD)
else()
message(FATAL_ERROR "Your platform '${CMAKE_SYSTEM_NAME}' is not supported!")
endif()
@@ -114,6 +117,12 @@ function(spvtools_default_compile_options TARGET)
endif()
endif()
+ if (MSVC)
+ # Specify /EHs for exception handling. This makes using SPIRV-Tools as
+ # dependencies in other projects easier.
+ target_compile_options(${TARGET} PRIVATE /EHs)
+ endif()
+
# For MinGW cross compile, statically link to the C++ runtime.
# But it still depends on MSVCRT.dll.
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 1dfaf875..fd627cbd 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -12,4 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# Add a SPIR-V Tools example. Signature:
+# add_spvtools_example(
+# TARGET target_name
+# SRCS src_file1.cpp src_file2.cpp
+# LIBS lib_target1 lib_target2
+# )
+function(add_spvtools_example)
+ if (NOT ${SPIRV_SKIP_EXECUTABLES})
+ set(one_value_args TARGET)
+ set(multi_value_args SRCS LIBS)
+ cmake_parse_arguments(
+ ARG "" "${one_value_args}" "${multi_value_args}" ${ARGN})
+
+ add_executable(${ARG_TARGET} ${ARG_SRCS})
+ spvtools_default_compile_options(${ARG_TARGET})
+ target_link_libraries(${ARG_TARGET} PRIVATE ${ARG_LIBS})
+ set_property(TARGET ${ARG_TARGET} PROPERTY FOLDER "SPIRV-Tools examples")
+ endif()
+endfunction()
+
add_subdirectory(cpp-interface)
diff --git a/examples/cpp-interface/CMakeLists.txt b/examples/cpp-interface/CMakeLists.txt
index 14f99376..d050b075 100644
--- a/examples/cpp-interface/CMakeLists.txt
+++ b/examples/cpp-interface/CMakeLists.txt
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-if (NOT ${SPIRV_SKIP_EXECUTABLES})
- add_executable(spirv-tools-cpp-example main.cpp)
- spvtools_default_compile_options(spirv-tools-cpp-example)
- target_link_libraries(spirv-tools-cpp-example PRIVATE SPIRV-Tools-opt)
-endif()
+add_spvtools_example(
+ TARGET spirv-tools-cpp-example
+ SRCS main.cpp
+ LIBS SPIRV-Tools-opt
+) \ No newline at end of file
diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt
index 182724ed..e710ddda 100644
--- a/external/CMakeLists.txt
+++ b/external/CMakeLists.txt
@@ -40,7 +40,18 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
"Use shared (DLL) run-time lib even when Google Test is built as static lib."
ON)
endif()
- add_subdirectory(${GMOCK_DIR})
+ add_subdirectory(${GMOCK_DIR} EXCLUDE_FROM_ALL)
endif()
endif()
+ if (TARGET gmock)
+ set(GTEST_TARGETS
+ gtest
+ gtest_main
+ gmock
+ gmock_main
+ )
+ foreach(target ${GTEST_TARGETS})
+ set_property(TARGET ${target} PROPERTY FOLDER GoogleTest)
+ endforeach()
+ endif()
endif()
diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h
index 1d7247fd..73e72cfe 100644
--- a/include/spirv-tools/libspirv.h
+++ b/include/spirv-tools/libspirv.h
@@ -244,6 +244,15 @@ typedef enum spv_number_kind_t {
SPV_NUMBER_FLOATING,
} spv_number_kind_t;
+typedef enum spv_text_to_binary_options_t {
+ SPV_TEXT_TO_BINARY_OPTION_NONE = SPV_BIT(0),
+ // Numeric IDs in the binary will have the same values as in the source.
+ // Non-numeric IDs are allocated by filling in the gaps, starting with 1
+ // and going up.
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS = SPV_BIT(1),
+ SPV_FORCE_32_BIT_ENUM(spv_text_to_binary_options_t)
+} spv_text_to_binary_options_t;
+
typedef enum spv_binary_to_text_options_t {
SPV_BINARY_TO_TEXT_OPTION_NONE = SPV_BIT(0),
SPV_BINARY_TO_TEXT_OPTION_PRINT = SPV_BIT(1),
@@ -369,7 +378,8 @@ typedef enum {
SPV_ENV_OPENGL_4_2, // OpenGL 4.2 plus GL_ARB_gl_spirv, latest revisions.
SPV_ENV_OPENGL_4_3, // OpenGL 4.3 plus GL_ARB_gl_spirv, latest revisions.
// There is no variant for OpenGL 4.4.
- SPV_ENV_OPENGL_4_5, // OpenGL 4.5 plus GL_ARB_gl_spirv, latest revisions.
+ SPV_ENV_OPENGL_4_5, // OpenGL 4.5 plus GL_ARB_gl_spirv, latest revisions.
+ SPV_ENV_UNIVERSAL_1_2, // SPIR-V 1.2, latest revision, no other restrictions.
} spv_target_env;
// SPIR-V Validator can be parameterized with the following Universal Limits.
@@ -416,6 +426,13 @@ spv_result_t spvTextToBinary(const spv_const_context context, const char* text,
const size_t length, spv_binary* binary,
spv_diagnostic* diagnostic);
+// Encodes the given SPIR-V assembly text to its binary representation. Same as
+// spvTextToBinary but with options. The options parameter is a bit field of
+// spv_text_to_binary_options_t.
+spv_result_t spvTextToBinaryWithOptions(
+ const spv_const_context context, const char* text, const size_t length,
+ const uint32_t options, spv_binary* binary, spv_diagnostic* diagnostic);
+
// Frees an allocated text stream. This is a no-op if the text parameter
// is a null pointer.
void spvTextDestroy(spv_text text);
diff --git a/include/spirv-tools/libspirv.hpp b/include/spirv-tools/libspirv.hpp
index dcd85a57..f82c1348 100644
--- a/include/spirv-tools/libspirv.hpp
+++ b/include/spirv-tools/libspirv.hpp
@@ -56,6 +56,9 @@ class ValidatorOptions {
class SpirvTools {
public:
enum {
+ // Default assembling option used by assemble():
+ kDefaultAssembleOption = SPV_TEXT_TO_BINARY_OPTION_NONE,
+
// Default disassembling option used by Disassemble():
// * Avoid prefix comments from decoding the SPIR-V module header, and
// * Use friendly names for variables.
@@ -86,11 +89,13 @@ class SpirvTools {
// Assembles the given assembly |text| and writes the result to |binary|.
// Returns true on successful assembling. |binary| will be kept untouched if
// assembling is unsuccessful.
- bool Assemble(const std::string& text, std::vector<uint32_t>* binary) const;
+ bool Assemble(const std::string& text, std::vector<uint32_t>* binary,
+ uint32_t options = kDefaultAssembleOption) const;
// |text_size| specifies the number of bytes in |text|. A terminating null
// character is not required to present in |text| as long as |text| is valid.
bool Assemble(const char* text, size_t text_size,
- std::vector<uint32_t>* binary) const;
+ std::vector<uint32_t>* binary,
+ uint32_t options = kDefaultAssembleOption) const;
// Disassembles the given SPIR-V |binary| with the given |options| and writes
// the assembly to |text|. Returns ture on successful disassembling. |text|
@@ -109,7 +114,8 @@ class SpirvTools {
// |binary_size| specifies the number of words in |binary|.
bool Validate(const uint32_t* binary, size_t binary_size) const;
// Like the previous overload, but takes an options object.
- bool Validate(const uint32_t* binary, size_t binary_size, const ValidatorOptions& options) const;
+ bool Validate(const uint32_t* binary, size_t binary_size,
+ const ValidatorOptions& options) const;
private:
struct Impl; // Opaque struct for holding the data fields used by this class.
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index 68bb6956..099a9445 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -102,13 +102,31 @@ Optimizer::PassToken CreateNullPass();
// Section 3.32.2 of the SPIR-V spec) of the SPIR-V module to be optimized.
Optimizer::PassToken CreateStripDebugInfoPass();
-// Creates a set-spec-constant-default-value pass.
+// Creates a set-spec-constant-default-value pass from a mapping from spec-ids
+// to the default values in the form of string.
// A set-spec-constant-default-value pass sets the default values for the
// spec constants that have SpecId decorations (i.e., those defined by
// OpSpecConstant{|True|False} instructions).
Optimizer::PassToken CreateSetSpecConstantDefaultValuePass(
const std::unordered_map<uint32_t, std::string>& id_value_map);
+// Creates a set-spec-constant-default-value pass from a mapping from spec-ids
+// to the default values in the form of bit pattern.
+// A set-spec-constant-default-value pass sets the default values for the
+// spec constants that have SpecId decorations (i.e., those defined by
+// OpSpecConstant{|True|False} instructions).
+Optimizer::PassToken CreateSetSpecConstantDefaultValuePass(
+ const std::unordered_map<uint32_t, std::vector<uint32_t>>& id_value_map);
+
+// Creates a flatten-decoration pass.
+// A flatten-decoration pass replaces grouped decorations with equivalent
+// ungrouped decorations. That is, it replaces each OpDecorationGroup
+// instruction and associated OpGroupDecorate and OpGroupMemberDecorate
+// instructions with equivalent OpDecorate and OpMemberDecorate instructions.
+// The pass does not attempt to preserve debug information for instructions
+// it removes.
+Optimizer::PassToken CreateFlattenDecorationPass();
+
// Creates a freeze-spec-constant-value pass.
// A freeze-spec-constant pass specializes the value of spec constants to
// their default values. This pass only processes the spec constants that have
@@ -175,6 +193,47 @@ Optimizer::PassToken CreateEliminateDeadConstantPass();
// size or runtime performance. Functions that are not designated as entry
// points are not changed.
Optimizer::PassToken CreateInlinePass();
+
+// Creates a single-block local variable load/store elimination pass.
+// For every entry point function, do single block memory optimization of
+// function variables referenced only with non-access-chain loads and stores.
+// For each targeted variable load, if previous store to that variable in the
+// block, replace the load's result id with the value id of the store.
+// If previous load within the block, replace the current load's result id
+// with the previous load's result id. In either case, delete the current
+// load. Finally, check if any remaining stores are useless, and delete store
+// and variable if possible.
+//
+// The presence of access chain references and function calls can inhibit
+// the above optimization.
+//
+// Only modules with logical addressing are currently processed.
+//
+// This pass is most effective if preceeded by Inlining and
+// LocalAccessChainConvert. This pass will reduce the work needed to be done
+// by LocalSingleStoreElim and LocalSSARewrite.
+Optimizer::PassToken CreateLocalSingleBlockLoadStoreElimPass();
+
+// Creates a local access chain conversion pass.
+// A local access chain conversion pass identifies all function scope
+// variables which are accessed only with loads, stores and access chains
+// with constant indices. It then converts all loads and stores of such
+// variables into equivalent sequences of loads, stores, extracts and inserts.
+//
+// This pass only processes entry point functions. It currently only converts
+// non-nested, non-ptr access chains. It does not process modules with
+// non-32-bit integer types present. Optional memory access options on loads
+// and stores are ignored as we are only processing function scope variables.
+//
+// This pass unifies access to these variables to a single mode and simplifies
+// subsequent analysis and elimination of these variables along with their
+// loads and stores allowing values to propagate to their points of use where
+// possible.
+Optimizer::PassToken CreateLocalAccessChainConvertPass();
+
+// Creates a compact ids pass.
+// The pass remaps result ids to a compact and gapless range starting from %1.
+Optimizer::PassToken CreateCompactIdsPass();
} // namespace spvtools
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index 55862169..8f8ca45e 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -106,18 +106,20 @@ macro(spvtools_vendor_tables VENDOR_TABLE)
COMMENT "Generate extended instruction tables for ${VENDOR_TABLE}.")
list(APPEND EXTINST_CPP_DEPENDS ${INSTS_FILE})
add_custom_target(spirv-tools-${VENDOR_TABLE} DEPENDS ${INSTS_FILE})
+ set_property(TARGET spirv-tools-${VENDOR_TABLE} PROPERTY FOLDER "SPIRV-Tools build")
endmacro(spvtools_vendor_tables)
spvtools_core_tables("1.0")
spvtools_core_tables("1.1")
-spvtools_enum_string_mapping("1.1")
+spvtools_core_tables("1.2")
+spvtools_enum_string_mapping("1.2")
spvtools_opencl_tables("1.0")
spvtools_glsl_tables("1.0")
-spvtools_vendor_tables("amd-gcn-shader")
+spvtools_vendor_tables("spv-amd-gcn-shader")
-spvtools_vimsyntax("1.1" "1.0")
+spvtools_vimsyntax("1.2" "1.0")
add_custom_target(spirv-tools-vimsyntax DEPENDS ${VIMSYNTAX_FILE})
-
+set_property(TARGET spirv-tools-vimsyntax PROPERTY FOLDER "SPIRV-Tools utilities")
# Extract the list of known generators from the SPIR-V XML registry file.
set(GENERATOR_INC_FILE ${spirv-tools_BINARY_DIR}/generators.inc)
@@ -181,6 +183,7 @@ add_custom_command(OUTPUT ${SPIRV_TOOLS_BUILD_VERSION_INC}
# This is not required for any dependence chain.
add_custom_target(spirv-tools-build-version
DEPENDS ${SPIRV_TOOLS_BUILD_VERSION_INC})
+set_property(TARGET spirv-tools-build-version PROPERTY FOLDER "SPIRV-Tools build")
add_subdirectory(opt)
@@ -188,11 +191,13 @@ set(SPIRV_SOURCES
${spirv-tools_SOURCE_DIR}/include/spirv-tools/libspirv.h
${CMAKE_CURRENT_SOURCE_DIR}/util/bitutils.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/util/bit_stream.h
${CMAKE_CURRENT_SOURCE_DIR}/util/hex_float.h
${CMAKE_CURRENT_SOURCE_DIR}/util/parse_number.h
${CMAKE_CURRENT_SOURCE_DIR}/util/string_utils.h
${CMAKE_CURRENT_SOURCE_DIR}/assembly_grammar.h
${CMAKE_CURRENT_SOURCE_DIR}/binary.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/cfa.h
${CMAKE_CURRENT_SOURCE_DIR}/diagnostic.h
${CMAKE_CURRENT_SOURCE_DIR}/enum_set.h
${CMAKE_CURRENT_SOURCE_DIR}/enum_string_mapping.h
@@ -215,6 +220,7 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/text_handler.h
${CMAKE_CURRENT_SOURCE_DIR}/validate.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/util/bit_stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/util/parse_number.cpp
${CMAKE_CURRENT_SOURCE_DIR}/util/string_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/assembly_grammar.cpp
@@ -233,6 +239,7 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/print.cpp
${CMAKE_CURRENT_SOURCE_DIR}/software_version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_endian.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/spirv_stats.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_target_env.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_validator_options.cpp
${CMAKE_CURRENT_SOURCE_DIR}/table.cpp
@@ -276,6 +283,7 @@ target_include_directories(${SPIRV_TOOLS}
PRIVATE ${spirv-tools_BINARY_DIR}
PRIVATE ${SPIRV_HEADER_INCLUDE_DIR}
)
+set_property(TARGET ${SPIRV_TOOLS} PROPERTY FOLDER "SPIRV-Tools libraries")
install(TARGETS ${SPIRV_TOOLS}
RUNTIME DESTINATION bin
diff --git a/source/assembly_grammar.h b/source/assembly_grammar.h
index fa5b9200..ac211369 100644
--- a/source/assembly_grammar.h
+++ b/source/assembly_grammar.h
@@ -17,7 +17,7 @@
#include "operand.h"
#include "spirv-tools/libspirv.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
#include "table.h"
namespace libspirv {
diff --git a/source/binary.cpp b/source/binary.cpp
index cdd2eab8..a803def7 100644
--- a/source/binary.cpp
+++ b/source/binary.cpp
@@ -27,7 +27,7 @@
#include "ext_inst.h"
#include "opcode.h"
#include "operand.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
#include "spirv_constant.h"
#include "spirv_endian.h"
diff --git a/source/cfa.h b/source/cfa.h
new file mode 100644
index 00000000..70c241a0
--- /dev/null
+++ b/source/cfa.h
@@ -0,0 +1,336 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SPVTOOLS_CFA_H_
+#define SPVTOOLS_CFA_H_
+
+#include <algorithm>
+#include <cassert>
+#include <functional>
+#include <map>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+using std::find;
+using std::function;
+using std::get;
+using std::pair;
+using std::unordered_map;
+using std::unordered_set;
+using std::vector;
+
+namespace spvtools {
+
+// Control Flow Analysis of control flow graphs of basic block nodes |BB|.
+template<class BB> class CFA {
+ using bb_ptr = BB*;
+ using cbb_ptr = const BB*;
+ using bb_iter = typename std::vector<BB*>::const_iterator;
+ using get_blocks_func =
+ std::function<const std::vector<BB*>*(const BB*)>;
+
+ struct block_info {
+ cbb_ptr block; ///< pointer to the block
+ bb_iter iter; ///< Iterator to the current child node being processed
+ };
+
+ /// Returns true if a block with @p id is found in the @p work_list vector
+ ///
+ /// @param[in] work_list Set of blocks visited in the the depth first traversal
+ /// of the CFG
+ /// @param[in] id The ID of the block being checked
+ ///
+ /// @return true if the edge work_list.back().block->id() => id is a back-edge
+ static bool FindInWorkList(
+ const std::vector<block_info>& work_list, uint32_t id);
+
+public:
+ /// @brief Depth first traversal starting from the \p entry BasicBlock
+ ///
+ /// This function performs a depth first traversal from the \p entry
+ /// BasicBlock and calls the pre/postorder functions when it needs to process
+ /// the node in pre order, post order. It also calls the backedge function
+ /// when a back edge is encountered.
+ ///
+ /// @param[in] entry The root BasicBlock of a CFG
+ /// @param[in] successor_func A function which will return a pointer to the
+ /// successor nodes
+ /// @param[in] preorder A function that will be called for every block in a
+ /// CFG following preorder traversal semantics
+ /// @param[in] postorder A function that will be called for every block in a
+ /// CFG following postorder traversal semantics
+ /// @param[in] backedge A function that will be called when a backedge is
+ /// encountered during a traversal
+ /// NOTE: The @p successor_func and predecessor_func each return a pointer to a
+ /// collection such that iterators to that collection remain valid for the
+ /// lifetime of the algorithm.
+ static void DepthFirstTraversal(const BB* entry,
+ get_blocks_func successor_func,
+ std::function<void(cbb_ptr)> preorder,
+ std::function<void(cbb_ptr)> postorder,
+ std::function<void(cbb_ptr, cbb_ptr)> backedge);
+
+ /// @brief Calculates dominator edges for a set of blocks
+ ///
+ /// Computes dominators using the algorithm of Cooper, Harvey, and Kennedy
+ /// "A Simple, Fast Dominance Algorithm", 2001.
+ ///
+ /// The algorithm assumes there is a unique root node (a node without
+ /// predecessors), and it is therefore at the end of the postorder vector.
+ ///
+ /// This function calculates the dominator edges for a set of blocks in the CFG.
+ /// Uses the dominator algorithm by Cooper et al.
+ ///
+ /// @param[in] postorder A vector of blocks in post order traversal order
+ /// in a CFG
+ /// @param[in] predecessor_func Function used to get the predecessor nodes of a
+ /// block
+ ///
+ /// @return the dominator tree of the graph, as a vector of pairs of nodes.
+ /// The first node in the pair is a node in the graph. The second node in the
+ /// pair is its immediate dominator in the sense of Cooper et.al., where a block
+ /// without predecessors (such as the root node) is its own immediate dominator.
+ static vector<pair<BB*, BB*>> CalculateDominators(
+ const vector<cbb_ptr>& postorder, get_blocks_func predecessor_func);
+
+ // Computes a minimal set of root nodes required to traverse, in the forward
+ // direction, the CFG represented by the given vector of blocks, and successor
+ // and predecessor functions. When considering adding two nodes, each having
+ // predecessors, favour using the one that appears earlier on the input blocks
+ // list.
+ static std::vector<BB*> TraversalRoots(
+ const std::vector<BB*>& blocks,
+ get_blocks_func succ_func,
+ get_blocks_func pred_func);
+
+ static void ComputeAugmentedCFG(
+ std::vector<BB*>& ordered_blocks,
+ BB* pseudo_entry_block,
+ BB* pseudo_exit_block,
+ std::unordered_map<const BB*, std::vector<BB*>>* augmented_successors_map,
+ std::unordered_map<const BB*, std::vector<BB*>>* augmented_predecessors_map,
+ get_blocks_func succ_func,
+ get_blocks_func pred_func);
+};
+
+template<class BB> bool CFA<BB>::FindInWorkList(const vector<block_info>& work_list,
+ uint32_t id) {
+ for (const auto b : work_list) {
+ if (b.block->id() == id) return true;
+ }
+ return false;
+}
+
+template<class BB> void CFA<BB>::DepthFirstTraversal(const BB* entry,
+ get_blocks_func successor_func,
+ function<void(cbb_ptr)> preorder,
+ function<void(cbb_ptr)> postorder,
+ function<void(cbb_ptr, cbb_ptr)> backedge) {
+ unordered_set<uint32_t> processed;
+
+ /// NOTE: work_list is the sequence of nodes from the root node to the node
+ /// being processed in the traversal
+ vector<block_info> work_list;
+ work_list.reserve(10);
+
+ work_list.push_back({ entry, begin(*successor_func(entry)) });
+ preorder(entry);
+ processed.insert(entry->id());
+
+ while (!work_list.empty()) {
+ block_info& top = work_list.back();
+ if (top.iter == end(*successor_func(top.block))) {
+ postorder(top.block);
+ work_list.pop_back();
+ }
+ else {
+ BB* child = *top.iter;
+ top.iter++;
+ if (FindInWorkList(work_list, child->id())) {
+ backedge(top.block, child);
+ }
+ if (processed.count(child->id()) == 0) {
+ preorder(child);
+ work_list.emplace_back(
+ block_info{ child, begin(*successor_func(child)) });
+ processed.insert(child->id());
+ }
+ }
+ }
+}
+
+template<class BB>
+vector<pair<BB*, BB*>> CFA<BB>::CalculateDominators(
+ const vector<cbb_ptr>& postorder, get_blocks_func predecessor_func) {
+ struct block_detail {
+ size_t dominator; ///< The index of blocks's dominator in post order array
+ size_t postorder_index; ///< The index of the block in the post order array
+ };
+ const size_t undefined_dom = postorder.size();
+
+ unordered_map<cbb_ptr, block_detail> idoms;
+ for (size_t i = 0; i < postorder.size(); i++) {
+ idoms[postorder[i]] = { undefined_dom, i };
+ }
+ idoms[postorder.back()].dominator = idoms[postorder.back()].postorder_index;
+
+ bool changed = true;
+ while (changed) {
+ changed = false;
+ for (auto b = postorder.rbegin() + 1; b != postorder.rend(); ++b) {
+ const vector<BB*>& predecessors = *predecessor_func(*b);
+ // Find the first processed/reachable predecessor that is reachable
+ // in the forward traversal.
+ auto res = find_if(begin(predecessors), end(predecessors),
+ [&idoms, undefined_dom](BB* pred) {
+ return idoms.count(pred) &&
+ idoms[pred].dominator != undefined_dom;
+ });
+ if (res == end(predecessors)) continue;
+ const BB* idom = *res;
+ size_t idom_idx = idoms[idom].postorder_index;
+
+ // all other predecessors
+ for (const auto* p : predecessors) {
+ if (idom == p) continue;
+ // Only consider nodes reachable in the forward traversal.
+ // Otherwise the intersection doesn't make sense and will never
+ // terminate.
+ if (!idoms.count(p)) continue;
+ if (idoms[p].dominator != undefined_dom) {
+ size_t finger1 = idoms[p].postorder_index;
+ size_t finger2 = idom_idx;
+ while (finger1 != finger2) {
+ while (finger1 < finger2) {
+ finger1 = idoms[postorder[finger1]].dominator;
+ }
+ while (finger2 < finger1) {
+ finger2 = idoms[postorder[finger2]].dominator;
+ }
+ }
+ idom_idx = finger1;
+ }
+ }
+ if (idoms[*b].dominator != idom_idx) {
+ idoms[*b].dominator = idom_idx;
+ changed = true;
+ }
+ }
+ }
+
+ vector<pair<bb_ptr, bb_ptr>> out;
+ for (auto idom : idoms) {
+ // NOTE: performing a const cast for convenient usage with
+ // UpdateImmediateDominators
+ out.push_back({ const_cast<BB*>(get<0>(idom)),
+ const_cast<BB*>(postorder[get<1>(idom).dominator]) });
+ }
+ return out;
+}
+
+template<class BB>
+std::vector<BB*> CFA<BB>::TraversalRoots(
+ const std::vector<BB*>& blocks,
+ get_blocks_func succ_func,
+ get_blocks_func pred_func) {
+ // The set of nodes which have been visited from any of the roots so far.
+ std::unordered_set<const BB*> visited;
+
+ auto mark_visited = [&visited](const BB* b) { visited.insert(b); };
+ auto ignore_block = [](const BB*) {};
+ auto ignore_blocks = [](const BB*, const BB*) {};
+
+
+ auto traverse_from_root = [&mark_visited, &succ_func, &ignore_block,
+ &ignore_blocks](const BB* entry) {
+ DepthFirstTraversal(
+ entry, succ_func, mark_visited, ignore_block, ignore_blocks);
+ };
+
+ std::vector<BB*> result;
+
+ // First collect nodes without predecessors.
+ for (auto block : blocks) {
+ if (pred_func(block)->empty()) {
+ assert(visited.count(block) == 0 && "Malformed graph!");
+ result.push_back(block);
+ traverse_from_root(block);
+ }
+ }
+
+ // Now collect other stranded nodes. These must be in unreachable cycles.
+ for (auto block : blocks) {
+ if (visited.count(block) == 0) {
+ result.push_back(block);
+ traverse_from_root(block);
+ }
+ }
+
+ return result;
+}
+
+template<class BB>
+void CFA<BB>::ComputeAugmentedCFG(
+ std::vector<BB*>& ordered_blocks,
+ BB* pseudo_entry_block, BB* pseudo_exit_block,
+ std::unordered_map<const BB*, std::vector<BB*>>* augmented_successors_map,
+ std::unordered_map<const BB*, std::vector<BB*>>* augmented_predecessors_map,
+ get_blocks_func succ_func,
+ get_blocks_func pred_func) {
+
+ // Compute the successors of the pseudo-entry block, and
+ // the predecessors of the pseudo exit block.
+ auto sources = TraversalRoots(ordered_blocks, succ_func, pred_func);
+
+ // For the predecessor traversals, reverse the order of blocks. This
+ // will affect the post-dominance calculation as follows:
+ // - Suppose you have blocks A and B, with A appearing before B in
+ // the list of blocks.
+ // - Also, A branches only to B, and B branches only to A.
+ // - We want to compute A as dominating B, and B as post-dominating B.
+ // By using reversed blocks for predecessor traversal roots discovery,
+ // we'll add an edge from B to the pseudo-exit node, rather than from A.
+ // All this is needed to correctly process the dominance/post-dominance
+ // constraint when A is a loop header that points to itself as its
+ // own continue target, and B is the latch block for the loop.
+ std::vector<BB*> reversed_blocks(ordered_blocks.rbegin(),
+ ordered_blocks.rend());
+ auto sinks = TraversalRoots(reversed_blocks, pred_func, succ_func);
+
+ // Wire up the pseudo entry block.
+ (*augmented_successors_map)[pseudo_entry_block] = sources;
+ for (auto block : sources) {
+ auto& augmented_preds = (*augmented_predecessors_map)[block];
+ const auto preds = pred_func(block);
+ augmented_preds.reserve(1 + preds->size());
+ augmented_preds.push_back(pseudo_entry_block);
+ augmented_preds.insert(augmented_preds.end(), preds->begin(), preds->end());
+ }
+
+ // Wire up the pseudo exit block.
+ (*augmented_predecessors_map)[pseudo_exit_block] = sinks;
+ for (auto block : sinks) {
+ auto& augmented_succ = (*augmented_successors_map)[block];
+ const auto succ = succ_func(block);
+ augmented_succ.reserve(1 + succ->size());
+ augmented_succ.push_back(pseudo_exit_block);
+ augmented_succ.insert(augmented_succ.end(), succ->begin(), succ->end());
+ }
+};
+
+} // namespace spvtools
+
+#endif // SPVTOOLS_CFA_H_
diff --git a/source/enum_set.h b/source/enum_set.h
index 0abc5941..6d3ec73b 100644
--- a/source/enum_set.h
+++ b/source/enum_set.h
@@ -21,7 +21,7 @@
#include <set>
#include <utility>
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
namespace libspirv {
diff --git a/source/ext_inst.cpp b/source/ext_inst.cpp
index 3cac7ac1..3f2b6ce6 100644
--- a/source/ext_inst.cpp
+++ b/source/ext_inst.cpp
@@ -31,8 +31,8 @@ static const spv_ext_inst_desc_t openclEntries_1_0[] = {
#include "opencl.std.insts-1.0.inc"
};
-static const spv_ext_inst_desc_t amd_gcn_shader_entries[] = {
-#include "amd-gcn-shader.insts.inc"
+static const spv_ext_inst_desc_t spv_amd_gcn_shader_entries[] = {
+#include "spv-amd-gcn-shader.insts.inc"
};
spv_result_t spvExtInstTableGet(spv_ext_inst_table* pExtInstTable,
@@ -44,8 +44,8 @@ spv_result_t spvExtInstTableGet(spv_ext_inst_table* pExtInstTable,
glslStd450Entries_1_0},
{SPV_EXT_INST_TYPE_OPENCL_STD, ARRAY_SIZE(openclEntries_1_0),
openclEntries_1_0},
- {SPV_EXT_INST_TYPE_SPV_AMD_GCN_SHADER, ARRAY_SIZE(amd_gcn_shader_entries),
- amd_gcn_shader_entries},
+ {SPV_EXT_INST_TYPE_SPV_AMD_GCN_SHADER,
+ ARRAY_SIZE(spv_amd_gcn_shader_entries), spv_amd_gcn_shader_entries},
};
static const spv_ext_inst_table_t table_1_0 = {ARRAY_SIZE(groups_1_0),
@@ -56,6 +56,7 @@ spv_result_t spvExtInstTableGet(spv_ext_inst_table* pExtInstTable,
case SPV_ENV_UNIVERSAL_1_0:
case SPV_ENV_VULKAN_1_0:
case SPV_ENV_UNIVERSAL_1_1:
+ case SPV_ENV_UNIVERSAL_1_2:
case SPV_ENV_OPENCL_2_1:
case SPV_ENV_OPENCL_2_2:
case SPV_ENV_OPENGL_4_0:
diff --git a/source/extinst.amd-gcn-shader.grammar.json b/source/extinst.spv-amd-gcn-shader.grammar.json
index 275186b6..e18251bb 100644
--- a/source/extinst.amd-gcn-shader.grammar.json
+++ b/source/extinst.spv-amd-gcn-shader.grammar.json
@@ -7,7 +7,7 @@
"operands" : [
{ "kind" : "IdRef", "name" : "'P'" }
],
- "extensions" : [ "SPV_KHR_gcn_shader" ]
+ "extensions" : [ "SPV_AMD_gcn_shader" ]
},
{
"opname" : "CubeFaceCoordAMD",
@@ -15,12 +15,12 @@
"operands" : [
{ "kind" : "IdRef", "name" : "'P'" }
],
- "extensions" : [ "SPV_KHR_gcn_shader" ]
+ "extensions" : [ "SPV_AMD_gcn_shader" ]
},
{
"opname" : "TimeAMD",
"opcode" : 3,
- "extensions" : [ "SPV_KHR_gcn_shader" ]
+ "extensions" : [ "SPV_AMD_gcn_shader" ]
}
]
}
diff --git a/source/instruction.h b/source/instruction.h
index 5dd4d139..2afa6d45 100644
--- a/source/instruction.h
+++ b/source/instruction.h
@@ -19,7 +19,7 @@
#include <vector>
#include "spirv-tools/libspirv.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
// Describes an instruction.
struct spv_instruction_t {
diff --git a/source/libspirv.cpp b/source/libspirv.cpp
index e390ffe8..77e03e57 100644
--- a/source/libspirv.cpp
+++ b/source/libspirv.cpp
@@ -39,15 +39,17 @@ void SpirvTools::SetMessageConsumer(MessageConsumer consumer) {
}
bool SpirvTools::Assemble(const std::string& text,
- std::vector<uint32_t>* binary) const {
- return Assemble(text.data(), text.size(), binary);
+ std::vector<uint32_t>* binary,
+ uint32_t options) const {
+ return Assemble(text.data(), text.size(), binary, options);
}
bool SpirvTools::Assemble(const char* text, const size_t text_size,
- std::vector<uint32_t>* binary) const {
+ std::vector<uint32_t>* binary,
+ uint32_t options) const {
spv_binary spvbinary = nullptr;
- spv_result_t status =
- spvTextToBinary(impl_->context, text, text_size, &spvbinary, nullptr);
+ spv_result_t status = spvTextToBinaryWithOptions(
+ impl_->context, text, text_size, options, &spvbinary, nullptr);
if (status == SPV_SUCCESS) {
binary->assign(spvbinary->code, spvbinary->code + spvbinary->wordCount);
}
diff --git a/source/name_mapper.cpp b/source/name_mapper.cpp
index eaf5e07e..1accc943 100644
--- a/source/name_mapper.cpp
+++ b/source/name_mapper.cpp
@@ -23,7 +23,7 @@
#include <unordered_set>
#include "spirv-tools/libspirv.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
#include "parsed_operand.h"
diff --git a/source/opcode.cpp b/source/opcode.cpp
index 9f51b957..c13e4ef1 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -36,6 +36,9 @@ const spv_opcode_desc_t opcodeTableEntries_1_0[] = {
const spv_opcode_desc_t opcodeTableEntries_1_1[] = {
#include "core.insts-1.1.inc"
};
+const spv_opcode_desc_t opcodeTableEntries_1_2[] = {
+#include "core.insts-1.2.inc"
+};
// Represents a vendor tool entry in the SPIR-V XML Regsitry.
struct VendorTool {
@@ -83,6 +86,8 @@ spv_result_t spvOpcodeTableGet(spv_opcode_table* pInstTable,
ARRAY_SIZE(opcodeTableEntries_1_0), opcodeTableEntries_1_0};
static const spv_opcode_table_t table_1_1 = {
ARRAY_SIZE(opcodeTableEntries_1_1), opcodeTableEntries_1_1};
+ static const spv_opcode_table_t table_1_2 = {
+ ARRAY_SIZE(opcodeTableEntries_1_2), opcodeTableEntries_1_2};
switch (env) {
case SPV_ENV_UNIVERSAL_1_0:
@@ -96,9 +101,12 @@ spv_result_t spvOpcodeTableGet(spv_opcode_table* pInstTable,
*pInstTable = &table_1_0;
return SPV_SUCCESS;
case SPV_ENV_UNIVERSAL_1_1:
- case SPV_ENV_OPENCL_2_2:
*pInstTable = &table_1_1;
return SPV_SUCCESS;
+ case SPV_ENV_UNIVERSAL_1_2:
+ case SPV_ENV_OPENCL_2_2:
+ *pInstTable = &table_1_2;
+ return SPV_SUCCESS;
}
assert(0 && "Unknown spv_target_env in spvOpcodeTableGet()");
return SPV_ERROR_INVALID_TABLE;
@@ -165,9 +173,9 @@ void spvInstructionCopy(const uint32_t* words, const SpvOp opcode,
const char* spvOpcodeString(const SpvOp opcode) {
// Use the latest SPIR-V version, which should be backward-compatible with all
// previous ones.
- for (uint32_t i = 0; i < ARRAY_SIZE(opcodeTableEntries_1_1); ++i) {
- if (opcodeTableEntries_1_1[i].opcode == opcode)
- return opcodeTableEntries_1_1[i].name;
+ for (uint32_t i = 0; i < ARRAY_SIZE(opcodeTableEntries_1_2); ++i) {
+ if (opcodeTableEntries_1_2[i].opcode == opcode)
+ return opcodeTableEntries_1_2[i].name;
}
assert(0 && "Unreachable!");
return "unknown";
@@ -230,6 +238,26 @@ int32_t spvOpcodeIsComposite(const SpvOp opcode) {
}
}
+bool spvOpcodeReturnsLogicalVariablePointer(const SpvOp opcode) {
+ switch (opcode) {
+ case SpvOpVariable:
+ case SpvOpAccessChain:
+ case SpvOpInBoundsAccessChain:
+ case SpvOpFunctionParameter:
+ case SpvOpImageTexelPointer:
+ case SpvOpCopyObject:
+ case SpvOpSelect:
+ case SpvOpPhi:
+ case SpvOpFunctionCall:
+ case SpvOpPtrAccessChain:
+ case SpvOpLoad:
+ case SpvOpConstantNull:
+ return true;
+ default:
+ return false;
+ }
+}
+
int32_t spvOpcodeReturnsLogicalPointer(const SpvOp opcode) {
switch (opcode) {
case SpvOpVariable:
diff --git a/source/opcode.h b/source/opcode.h
index 9742cbca..4e06efdf 100644
--- a/source/opcode.h
+++ b/source/opcode.h
@@ -17,7 +17,7 @@
#include "instruction.h"
#include "spirv-tools/libspirv.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
#include "table.h"
// Returns the name of a registered SPIR-V generator as a null-terminated
@@ -79,6 +79,10 @@ int32_t spvOpcodeIsComposite(const SpvOp opcode);
// addressing model. Returns zero if false, non-zero otherwise.
int32_t spvOpcodeReturnsLogicalPointer(const SpvOp opcode);
+// Returns whether the given opcode could result in a pointer or a variable
+// pointer when using the logical addressing model.
+bool spvOpcodeReturnsLogicalVariablePointer(const SpvOp opcode);
+
// Determines if the given opcode generates a type. Returns zero if false,
// non-zero otherwise.
int32_t spvOpcodeGeneratesType(SpvOp opcode);
diff --git a/source/operand.cpp b/source/operand.cpp
index 621c7430..cf234e77 100644
--- a/source/operand.cpp
+++ b/source/operand.cpp
@@ -26,6 +26,9 @@ namespace v1_0 {
namespace v1_1 {
#include "operand.kinds-1.1.inc"
} // namespace v1_1
+namespace v1_2 {
+#include "operand.kinds-1.2.inc"
+} // namespace v1_2
spv_result_t spvOperandTableGet(spv_operand_table* pOperandTable,
spv_target_env env) {
@@ -37,6 +40,9 @@ spv_result_t spvOperandTableGet(spv_operand_table* pOperandTable,
static const spv_operand_table_t table_1_1 = {
ARRAY_SIZE(v1_1::pygen_variable_OperandInfoTable),
v1_1::pygen_variable_OperandInfoTable};
+ static const spv_operand_table_t table_1_2 = {
+ ARRAY_SIZE(v1_2::pygen_variable_OperandInfoTable),
+ v1_2::pygen_variable_OperandInfoTable};
switch (env) {
case SPV_ENV_UNIVERSAL_1_0:
@@ -50,9 +56,12 @@ spv_result_t spvOperandTableGet(spv_operand_table* pOperandTable,
*pOperandTable = &table_1_0;
return SPV_SUCCESS;
case SPV_ENV_UNIVERSAL_1_1:
- case SPV_ENV_OPENCL_2_2:
*pOperandTable = &table_1_1;
return SPV_SUCCESS;
+ case SPV_ENV_UNIVERSAL_1_2:
+ case SPV_ENV_OPENCL_2_2:
+ *pOperandTable = &table_1_2;
+ return SPV_SUCCESS;
}
assert(0 && "Unknown spv_target_env in spvOperandTableGet()");
return SPV_ERROR_INVALID_TABLE;
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 55949c83..64fdc025 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -14,15 +14,19 @@
add_library(SPIRV-Tools-opt
basic_block.h
build_module.h
+ compact_ids_pass.h
constants.h
def_use_manager.h
eliminate_dead_constant_pass.h
+ flatten_decoration_pass.h
function.h
fold_spec_constant_op_and_composite_pass.h
freeze_spec_constant_value_pass.h
inline_pass.h
instruction.h
ir_loader.h
+ local_access_chain_convert_pass.h
+ local_single_block_elim_pass.h
log.h
module.h
null_pass.h
@@ -38,14 +42,18 @@ add_library(SPIRV-Tools-opt
basic_block.cpp
build_module.cpp
+ compact_ids_pass.cpp
def_use_manager.cpp
eliminate_dead_constant_pass.cpp
+ flatten_decoration_pass.cpp
function.cpp
fold_spec_constant_op_and_composite_pass.cpp
freeze_spec_constant_value_pass.cpp
inline_pass.cpp
instruction.cpp
ir_loader.cpp
+ local_access_chain_convert_pass.cpp
+ local_single_block_elim_pass.cpp
module.cpp
set_spec_constant_default_value_pass.cpp
optimizer.cpp
@@ -66,6 +74,8 @@ target_include_directories(SPIRV-Tools-opt
target_link_libraries(SPIRV-Tools-opt
PUBLIC ${SPIRV_TOOLS})
+set_property(TARGET SPIRV-Tools-opt PROPERTY FOLDER "SPIRV-Tools libraries")
+
install(TARGETS SPIRV-Tools-opt
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h
index 05258a9f..73249051 100644
--- a/source/opt/basic_block.h
+++ b/source/opt/basic_block.h
@@ -48,12 +48,16 @@ class BasicBlock {
Instruction& Label() { return *label_; }
// Returns the id of the label at the top of this block
- inline uint32_t label_id() const { return label_->result_id(); }
+ inline uint32_t id() const { return label_->result_id(); }
iterator begin() { return iterator(&insts_, insts_.begin()); }
iterator end() { return iterator(&insts_, insts_.end()); }
- const_iterator cbegin() { return const_iterator(&insts_, insts_.cbegin()); }
- const_iterator cend() { return const_iterator(&insts_, insts_.cend()); }
+ const_iterator cbegin() const {
+ return const_iterator(&insts_, insts_.cbegin());
+ }
+ const_iterator cend() const {
+ return const_iterator(&insts_, insts_.cend());
+ }
// Runs the given function |f| on each instruction in this basic block, and
// optionally on the debug line instructions that might precede them.
diff --git a/source/opt/build_module.cpp b/source/opt/build_module.cpp
index c1daea5b..766dbb52 100644
--- a/source/opt/build_module.cpp
+++ b/source/opt/build_module.cpp
@@ -64,11 +64,12 @@ std::unique_ptr<ir::Module> BuildModule(spv_target_env env,
std::unique_ptr<ir::Module> BuildModule(spv_target_env env,
MessageConsumer consumer,
- const std::string& text) {
+ const std::string& text,
+ uint32_t assemble_options) {
SpirvTools t(env);
t.SetMessageConsumer(consumer);
std::vector<uint32_t> binary;
- if (!t.Assemble(text, &binary)) return nullptr;
+ if (!t.Assemble(text, &binary, assemble_options)) return nullptr;
return BuildModule(env, consumer, binary.data(), binary.size());
}
diff --git a/source/opt/build_module.h b/source/opt/build_module.h
index d396a3a9..c27b1aeb 100644
--- a/source/opt/build_module.h
+++ b/source/opt/build_module.h
@@ -27,16 +27,16 @@ namespace spvtools {
// specifies number of words in |binary|. The |binary| will be decoded
// according to the given target |env|. Returns nullptr if erors occur and
// sends the errors to |consumer|.
-std::unique_ptr<ir::Module> BuildModule(spv_target_env env,
- MessageConsumer consumer,
- const uint32_t* binary, size_t size);
+std::unique_ptr<ir::Module> BuildModule(
+ spv_target_env env, MessageConsumer consumer, const uint32_t* binary,
+ size_t size);
// Builds and returns an ir::Module from the given SPIR-V assembly |text|.
// The |text| will be encoded according to the given target |env|. Returns
// nullptr if erors occur and sends the errors to |consumer|.
-std::unique_ptr<ir::Module> BuildModule(spv_target_env env,
- MessageConsumer consumer,
- const std::string& text);
+std::unique_ptr<ir::Module> BuildModule(
+ spv_target_env env, MessageConsumer consumer, const std::string& text,
+ uint32_t assemble_options = SpirvTools::kDefaultAssembleOption);
} // namespace spvtools
diff --git a/source/opt/compact_ids_pass.cpp b/source/opt/compact_ids_pass.cpp
new file mode 100644
index 00000000..39792e86
--- /dev/null
+++ b/source/opt/compact_ids_pass.cpp
@@ -0,0 +1,57 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compact_ids_pass.h"
+
+#include <cassert>
+#include <unordered_map>
+
+namespace spvtools {
+namespace opt {
+
+using ir::Instruction;
+using ir::Operand;
+
+Pass::Status CompactIdsPass::Process(ir::Module* module) {
+ bool modified = false;
+ std::unordered_map<uint32_t, uint32_t> result_id_mapping;
+
+ module->ForEachInst([&result_id_mapping, &modified] (Instruction* inst) {
+ auto operand = inst->begin();
+ while (operand != inst->end()) {
+ if (spvIsIdType(operand->type)) {
+ assert(operand->words.size() == 1);
+ uint32_t& id = operand->words[0];
+ auto it = result_id_mapping.find(id);
+ if (it == result_id_mapping.end()) {
+ const uint32_t new_id =
+ static_cast<uint32_t>(result_id_mapping.size()) + 1;
+ const auto insertion_result = result_id_mapping.emplace(id, new_id);
+ it = insertion_result.first;
+ assert(insertion_result.second);
+ }
+ if (id != it->second) {
+ modified = true;
+ id = it->second;
+ }
+ }
+ ++operand;
+ }
+ }, true);
+
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/compact_ids_pass.h b/source/opt/compact_ids_pass.h
new file mode 100644
index 00000000..41918dd8
--- /dev/null
+++ b/source/opt/compact_ids_pass.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_COMPACT_IDS_PASS_H_
+#define LIBSPIRV_OPT_COMPACT_IDS_PASS_H_
+
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class CompactIdsPass : public Pass {
+ public:
+ const char* name() const override { return "compact-ids"; }
+ Status Process(ir::Module*) override;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_COMPACT_IDS_PASS_H_
diff --git a/source/opt/def_use_manager.cpp b/source/opt/def_use_manager.cpp
index 2d0e79a5..3d8fd50e 100644
--- a/source/opt/def_use_manager.cpp
+++ b/source/opt/def_use_manager.cpp
@@ -130,6 +130,11 @@ bool DefUseManager::ReplaceAllUsesWith(uint32_t before, uint32_t after) {
// Make the modification in the instruction.
it->inst->SetInOperand(in_operand_pos, {after});
}
+ // Update inst to used ids map
+ auto iter = inst_to_used_ids_.find(it->inst);
+ if (iter != inst_to_used_ids_.end())
+ for (auto uit = iter->second.begin(); uit != iter->second.end(); uit++)
+ if (*uit == before) *uit = after;
// Register the use of |after| id into id_to_uses_.
// TODO(antiagainst): de-duplication.
id_to_uses_[after].push_back({it->inst, it->operand_index});
diff --git a/source/opt/flatten_decoration_pass.cpp b/source/opt/flatten_decoration_pass.cpp
new file mode 100644
index 00000000..98bb69cf
--- /dev/null
+++ b/source/opt/flatten_decoration_pass.cpp
@@ -0,0 +1,164 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "flatten_decoration_pass.h"
+
+#include <cassert>
+#include <vector>
+#include <unordered_map>
+#include <unordered_set>
+
+namespace spvtools {
+namespace opt {
+
+using ir::Instruction;
+using ir::Operand;
+
+using Words = std::vector<uint32_t>;
+using OrderedUsesMap = std::unordered_map<uint32_t, Words>;
+
+Pass::Status FlattenDecorationPass::Process(ir::Module* module) {
+ bool modified = false;
+
+ // The target Id of OpDecorationGroup instructions.
+ // We have to track this separately from its uses, in case it
+ // has no uses.
+ std::unordered_set<uint32_t> group_ids;
+ // Maps a decoration group Id to its GroupDecorate targets, in order
+ // of appearance.
+ OrderedUsesMap normal_uses;
+ // Maps a decoration group Id to its GroupMemberDecorate targets and
+ // their indices, in of appearance.
+ OrderedUsesMap member_uses;
+
+ auto annotations = module->annotations();
+
+ // On the first pass, record each OpDecorationGroup with its ordered uses.
+ // Rely on unordered_map::operator[] to create its entries on first access.
+ for (const auto& inst : annotations) {
+ switch (inst.opcode()) {
+ case SpvOp::SpvOpDecorationGroup:
+ group_ids.insert(inst.result_id());
+ break;
+ case SpvOp::SpvOpGroupDecorate: {
+ Words& words = normal_uses[inst.GetSingleWordInOperand(0)];
+ for (uint32_t i = 1; i < inst.NumInOperandWords(); i++) {
+ words.push_back(inst.GetSingleWordInOperand(i));
+ }
+ } break;
+ case SpvOp::SpvOpGroupMemberDecorate: {
+ Words& words = member_uses[inst.GetSingleWordInOperand(0)];
+ for (uint32_t i = 1; i < inst.NumInOperandWords(); i++) {
+ words.push_back(inst.GetSingleWordInOperand(i));
+ }
+ } break;
+ default:
+ break;
+ }
+ }
+
+ // On the second pass, replace OpDecorationGroup and its uses with
+ // equivalent normal and struct member uses.
+ auto inst_iter = annotations.begin();
+ // We have to re-evaluate the end pointer
+ while (inst_iter != module->annotations().end()) {
+ // Should we replace this instruction?
+ bool replace = false;
+ switch (inst_iter->opcode()) {
+ case SpvOp::SpvOpDecorationGroup:
+ case SpvOp::SpvOpGroupDecorate:
+ case SpvOp::SpvOpGroupMemberDecorate:
+ replace = true;
+ break;
+ case SpvOp::SpvOpDecorate: {
+ // If this decoration targets a group, then replace it
+ // by sets of normal and member decorations.
+ const uint32_t group = inst_iter->GetSingleWordOperand(0);
+ const auto normal_uses_iter = normal_uses.find(group);
+ if (normal_uses_iter != normal_uses.end()) {
+ for (auto target : normal_uses[group]) {
+ std::unique_ptr<Instruction> new_inst(new Instruction(*inst_iter));
+ new_inst->SetInOperand(0, Words{target});
+ inst_iter = inst_iter.InsertBefore(std::move(new_inst));
+ ++inst_iter;
+ replace = true;
+ }
+ }
+ const auto member_uses_iter = member_uses.find(group);
+ if (member_uses_iter != member_uses.end()) {
+ const Words& member_id_pairs = (*member_uses_iter).second;
+ // The collection is a sequence of pairs.
+ assert((member_id_pairs.size() % 2) == 0);
+ for (size_t i = 0; i < member_id_pairs.size(); i += 2) {
+ // Make an OpMemberDecorate instruction for each (target, member)
+ // pair.
+ const uint32_t target = member_id_pairs[i];
+ const uint32_t member = member_id_pairs[i + 1];
+ std::vector<Operand> operands;
+ operands.push_back(Operand(SPV_OPERAND_TYPE_ID, {target}));
+ operands.push_back(
+ Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {member}));
+ auto decoration_operands_iter = inst_iter->begin();
+ decoration_operands_iter++; // Skip the group target.
+ operands.insert(operands.end(), decoration_operands_iter,
+ inst_iter->end());
+ std::unique_ptr<Instruction> new_inst(
+ new Instruction(SpvOp::SpvOpMemberDecorate, 0, 0, operands));
+ inst_iter = inst_iter.InsertBefore(std::move(new_inst));
+ ++inst_iter;
+ replace = true;
+ }
+ }
+ // If this is an OpDecorate targeting the OpDecorationGroup itself,
+ // remove it even if that decoration group itself is not the target of
+ // any OpGroupDecorate or OpGroupMemberDecorate.
+ if (!replace && group_ids.count(group)) {
+ replace = true;
+ }
+ } break;
+ default:
+ break;
+ }
+ if (replace) {
+ inst_iter = inst_iter.Erase();
+ modified = true;
+ } else {
+ // Handle the case of decorations unrelated to decoration groups.
+ ++inst_iter;
+ }
+ }
+
+ // Remove OpName instructions which reference the removed group decorations.
+ // An OpDecorationGroup instruction might not have been used by an
+ // OpGroupDecorate or OpGroupMemberDecorate instruction.
+ if (!group_ids.empty()) {
+ for (auto debug_inst_iter = module->debug_begin();
+ debug_inst_iter != module->debug_end();) {
+ if (debug_inst_iter->opcode() == SpvOp::SpvOpName) {
+ const uint32_t target = debug_inst_iter->GetSingleWordOperand(0);
+ if (group_ids.count(target)) {
+ debug_inst_iter = debug_inst_iter.Erase();
+ modified = true;
+ } else {
+ ++debug_inst_iter;
+ }
+ }
+ }
+ }
+
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/flatten_decoration_pass.h b/source/opt/flatten_decoration_pass.h
new file mode 100644
index 00000000..bcdfdc07
--- /dev/null
+++ b/source/opt/flatten_decoration_pass.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_FLATTEN_DECORATION_PASS_H_
+#define LIBSPIRV_OPT_FLATTEN_DECORATION_PASS_H_
+
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class FlattenDecorationPass : public Pass {
+ public:
+ const char* name() const override { return "flatten-decoration"; }
+ Status Process(ir::Module*) override;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_FLATTEN_DECORATION_PASS_H_
diff --git a/source/opt/function.h b/source/opt/function.h
index 2e0674e0..949f99a5 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -59,8 +59,12 @@ class Function {
iterator begin() { return iterator(&blocks_, blocks_.begin()); }
iterator end() { return iterator(&blocks_, blocks_.end()); }
- const_iterator cbegin() { return const_iterator(&blocks_, blocks_.cbegin()); }
- const_iterator cend() { return const_iterator(&blocks_, blocks_.cend()); }
+ const_iterator cbegin() const {
+ return const_iterator(&blocks_, blocks_.cbegin());
+ }
+ const_iterator cend() const {
+ return const_iterator(&blocks_, blocks_.cend());
+ }
// Runs the given function |f| on each instruction in this function, and
// optionally on debug line instructions that might precede them.
@@ -81,7 +85,7 @@ class Function {
std::unique_ptr<Instruction> def_inst_;
// All parameters to this function.
std::vector<std::unique_ptr<Instruction>> params_;
- // All basic blocks inside this function.
+ // All basic blocks inside this function in specification order
std::vector<std::unique_ptr<BasicBlock>> blocks_;
// The OpFunctionEnd instruction.
std::unique_ptr<Instruction> end_inst_;
diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp
index 26dd4a37..de55688d 100644
--- a/source/opt/inline_pass.cpp
+++ b/source/opt/inline_pass.cpp
@@ -15,6 +15,7 @@
// limitations under the License.
#include "inline_pass.h"
+#include "cfa.h"
// Indices of operands in SPIR-V instructions
@@ -24,6 +25,8 @@ static const int kSpvFunctionCallArgumentId = 3;
static const int kSpvReturnValueId = 0;
static const int kSpvTypePointerStorageClass = 1;
static const int kSpvTypePointerTypeId = 2;
+static const int kSpvLoopMergeMergeBlockId = 0;
+static const int kSpvSelectionMergeMergeBlockId = 0;
namespace spvtools {
namespace opt {
@@ -55,13 +58,33 @@ uint32_t InlinePass::AddPointerToType(uint32_t type_id,
}
void InlinePass::AddBranch(uint32_t label_id,
- std::unique_ptr<ir::BasicBlock>* block_ptr) {
+ std::unique_ptr<ir::BasicBlock>* block_ptr) {
+ std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction(
+ SpvOpBranch, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}}));
+ (*block_ptr)->AddInstruction(std::move(newBranch));
+}
+
+void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id,
+ uint32_t false_id, std::unique_ptr<ir::BasicBlock>* block_ptr) {
std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction(
- SpvOpBranch, 0, 0,
- {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}}));
+ SpvOpBranchConditional, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {false_id}}}));
(*block_ptr)->AddInstruction(std::move(newBranch));
}
+void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id,
+ std::unique_ptr<ir::BasicBlock>* block_ptr) {
+ std::unique_ptr<ir::Instruction> newLoopMerge(new ir::Instruction(
+ SpvOpLoopMerge, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {continue_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LOOP_CONTROL, {0}}}));
+ (*block_ptr)->AddInstruction(std::move(newLoopMerge));
+}
+
void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id,
std::unique_ptr<ir::BasicBlock>* block_ptr) {
std::unique_ptr<ir::Instruction> newStore(new ir::Instruction(
@@ -84,6 +107,22 @@ std::unique_ptr<ir::Instruction> InlinePass::NewLabel(uint32_t label_id) {
return newLabel;
}
+uint32_t InlinePass::GetFalseId() {
+ if (false_id_ != 0)
+ return false_id_;
+ false_id_ = module_->GetGlobalValue(SpvOpConstantFalse);
+ if (false_id_ != 0)
+ return false_id_;
+ uint32_t boolId = module_->GetGlobalValue(SpvOpTypeBool);
+ if (boolId == 0) {
+ boolId = TakeNextId();
+ module_->AddGlobalValue(SpvOpTypeBool, boolId, 0);
+ }
+ false_id_ = TakeNextId();
+ module_->AddGlobalValue(SpvOpConstantFalse, false_id_, boolId);
+ return false_id_;
+}
+
void InlinePass::MapParams(
ir::Function* calleeFn,
ir::UptrVectorIterator<ir::Instruction> call_inst_itr,
@@ -94,7 +133,7 @@ void InlinePass::MapParams(
const uint32_t pid = cpi->result_id();
(*callee2caller)[pid] = call_inst_itr->GetSingleWordOperand(
kSpvFunctionCallArgumentId + param_idx);
- param_idx++;
+ ++param_idx;
});
}
@@ -111,7 +150,7 @@ void InlinePass::CloneAndMapLocals(
var_inst->SetResultId(newId);
(*callee2caller)[callee_var_itr->result_id()] = newId;
new_vars->push_back(std::move(var_inst));
- callee_var_itr++;
+ ++callee_var_itr;
}
}
@@ -189,6 +228,10 @@ void InlinePass::GenInlineCode(
ir::Function* calleeFn = id2function_[call_inst_itr->GetSingleWordOperand(
kSpvFunctionCallFunctionId)];
+ // Check for early returns
+ auto fi = early_return_.find(calleeFn->result_id());
+ bool earlyReturn = fi != early_return_.end();
+
// Map parameters to actual arguments.
MapParams(calleeFn, call_inst_itr, &callee2caller);
@@ -202,6 +245,8 @@ void InlinePass::GenInlineCode(
// Clone and map callee code. Copy caller block code to beginning of
// first block and end of last block.
bool prevInstWasReturn = false;
+ uint32_t singleTripLoopHeaderId = 0;
+ uint32_t singleTripLoopContinueId = 0;
uint32_t returnLabelId = 0;
bool multiBlocks = false;
const uint32_t calleeTypeId = calleeFn->type_id();
@@ -209,7 +254,9 @@ void InlinePass::GenInlineCode(
calleeFn->ForEachInst([&new_blocks, &callee2caller, &call_block_itr,
&call_inst_itr, &new_blk_ptr, &prevInstWasReturn,
&returnLabelId, &returnVarId, &calleeTypeId,
- &multiBlocks, &postCallSB, &preCallSB, this](
+ &multiBlocks, &postCallSB, &preCallSB, &earlyReturn,
+ &singleTripLoopHeaderId, &singleTripLoopContinueId,
+ this](
const ir::Instruction* cpi) {
switch (cpi->opcode()) {
case SpvOpFunction:
@@ -239,7 +286,7 @@ void InlinePass::GenInlineCode(
} else {
// First block needs to use label of original block
// but map callee label in case of phi reference.
- labelId = call_block_itr->label_id();
+ labelId = call_block_itr->id();
callee2caller[cpi->result_id()] = labelId;
firstBlock = true;
}
@@ -248,7 +295,7 @@ void InlinePass::GenInlineCode(
if (firstBlock) {
// Copy contents of original caller block up to call instruction.
for (auto cii = call_block_itr->begin(); cii != call_inst_itr;
- cii++) {
+ ++cii) {
std::unique_ptr<ir::Instruction> cp_inst(new ir::Instruction(*cii));
// Remember same-block ops for possible regeneration.
if (IsSameBlockOp(&*cp_inst)) {
@@ -257,6 +304,24 @@ void InlinePass::GenInlineCode(
}
new_blk_ptr->AddInstruction(std::move(cp_inst));
}
+ // If callee is early return function, insert header block for
+ // one-trip loop that will encompass callee code. Start postheader
+ // block.
+ if (earlyReturn) {
+ singleTripLoopHeaderId = this->TakeNextId();
+ AddBranch(singleTripLoopHeaderId, &new_blk_ptr);
+ new_blocks->push_back(std::move(new_blk_ptr));
+ new_blk_ptr.reset(new ir::BasicBlock(NewLabel(
+ singleTripLoopHeaderId)));
+ returnLabelId = this->TakeNextId();
+ singleTripLoopContinueId = this->TakeNextId();
+ AddLoopMerge(returnLabelId, singleTripLoopContinueId, &new_blk_ptr);
+ uint32_t postHeaderId = this->TakeNextId();
+ AddBranch(postHeaderId, &new_blk_ptr);
+ new_blocks->push_back(std::move(new_blk_ptr));
+ new_blk_ptr.reset(new ir::BasicBlock(NewLabel(postHeaderId)));
+ multiBlocks = true;
+ }
} else {
multiBlocks = true;
}
@@ -281,12 +346,17 @@ void InlinePass::GenInlineCode(
prevInstWasReturn = true;
} break;
case SpvOpFunctionEnd: {
- // If there was an early return, create return label/block.
+ // If there was an early return, insert continue and return blocks.
// If previous instruction was return, insert branch instruction
// to return block.
if (returnLabelId != 0) {
if (prevInstWasReturn) AddBranch(returnLabelId, &new_blk_ptr);
new_blocks->push_back(std::move(new_blk_ptr));
+ new_blk_ptr.reset(new ir::BasicBlock(NewLabel(
+ singleTripLoopContinueId)));
+ AddBranchCond(GetFalseId(), singleTripLoopHeaderId, returnLabelId,
+ &new_blk_ptr);
+ new_blocks->push_back(std::move(new_blk_ptr));
new_blk_ptr.reset(new ir::BasicBlock(NewLabel(returnLabelId)));
multiBlocks = true;
}
@@ -298,7 +368,7 @@ void InlinePass::GenInlineCode(
}
// Copy remaining instructions from caller block.
auto cii = call_inst_itr;
- for (cii++; cii != call_block_itr->end(); cii++) {
+ for (++cii; cii != call_block_itr->end(); ++cii) {
std::unique_ptr<ir::Instruction> cp_inst(new ir::Instruction(*cii));
// If multiple blocks generated, regenerate any same-block
// instruction that has not been seen in this last block.
@@ -322,7 +392,7 @@ void InlinePass::GenInlineCode(
const auto mapItr = callee2caller.find(*iid);
if (mapItr != callee2caller.end()) {
*iid = mapItr->second;
- } else if (cpi->has_labels()) {
+ } else if (cpi->HasLabels()) {
const ir::Instruction* inst =
def_use_mgr_->id_to_defs().find(*iid)->second;
if (inst->opcode() == SpvOpLabel) {
@@ -347,16 +417,24 @@ void InlinePass::GenInlineCode(
});
// Update block map given replacement blocks.
for (auto& blk : *new_blocks) {
- id2block_[blk->label_id()] = &*blk;
+ id2block_[blk->id()] = &*blk;
}
}
+bool InlinePass::IsInlinableFunctionCall(const ir::Instruction* inst) {
+ if (inst->opcode() != SpvOp::SpvOpFunctionCall) return false;
+ const uint32_t calleeFnId =
+ inst->GetSingleWordOperand(kSpvFunctionCallFunctionId);
+ const auto ci = inlinable_.find(calleeFnId);
+ return ci != inlinable_.cend();
+}
+
bool InlinePass::Inline(ir::Function* func) {
bool modified = false;
// Using block iterators here because of block erasures and insertions.
- for (auto bi = func->begin(); bi != func->end(); bi++) {
+ for (auto bi = func->begin(); bi != func->end(); ++bi) {
for (auto ii = bi->begin(); ii != bi->end();) {
- if (ii->opcode() == SpvOp::SpvOpFunctionCall) {
+ if (IsInlinableFunctionCall(&*ii)) {
// Inline call.
std::vector<std::unique_ptr<ir::BasicBlock>> newBlocks;
std::vector<std::unique_ptr<ir::Instruction>> newVars;
@@ -366,10 +444,10 @@ bool InlinePass::Inline(ir::Function* func) {
if (newBlocks.size() > 1) {
const auto firstBlk = newBlocks.begin();
const auto lastBlk = newBlocks.end() - 1;
- const uint32_t firstId = (*firstBlk)->label_id();
- const uint32_t lastId = (*lastBlk)->label_id();
- (*lastBlk)
- ->ForEachSuccessorLabel([&firstId, &lastId, this](uint32_t succ) {
+ const uint32_t firstId = (*firstBlk)->id();
+ const uint32_t lastId = (*lastBlk)->id();
+ (*lastBlk)->ForEachSuccessorLabel(
+ [&firstId, &lastId, this](uint32_t succ) {
ir::BasicBlock* sbp = this->id2block_[succ];
sbp->ForEachPhiInst([&firstId, &lastId](ir::Instruction* phi) {
phi->ForEachInId([&firstId, &lastId](uint32_t* id) {
@@ -387,13 +465,134 @@ bool InlinePass::Inline(ir::Function* func) {
ii = bi->begin();
modified = true;
} else {
- ii++;
+ ++ii;
}
}
}
return modified;
}
+bool InlinePass::HasMultipleReturns(ir::Function* func) {
+ bool seenReturn = false;
+ bool multipleReturns = false;
+ for (auto& blk : *func) {
+ auto terminal_ii = blk.cend();
+ --terminal_ii;
+ if (terminal_ii->opcode() == SpvOpReturn ||
+ terminal_ii->opcode() == SpvOpReturnValue) {
+ if (seenReturn) {
+ multipleReturns = true;
+ break;
+ }
+ seenReturn = true;
+ }
+ }
+ return multipleReturns;
+}
+
+uint32_t InlinePass::MergeBlockIdIfAny(const ir::BasicBlock& blk) {
+ auto merge_ii = blk.cend();
+ --merge_ii;
+ uint32_t mbid = 0;
+ if (merge_ii != blk.cbegin()) {
+ --merge_ii;
+ if (merge_ii->opcode() == SpvOpLoopMerge)
+ mbid = merge_ii->GetSingleWordOperand(kSpvLoopMergeMergeBlockId);
+ else if (merge_ii->opcode() == SpvOpSelectionMerge)
+ mbid = merge_ii->GetSingleWordOperand(kSpvSelectionMergeMergeBlockId);
+ }
+ return mbid;
+}
+
+void InlinePass::ComputeStructuredSuccessors(ir::Function* func) {
+ // If header, make merge block first successor.
+ for (auto& blk : *func) {
+ uint32_t mbid = MergeBlockIdIfAny(blk);
+ if (mbid != 0)
+ block2structured_succs_[&blk].push_back(id2block_[mbid]);
+ // add true successors
+ blk.ForEachSuccessorLabel([&blk, this](uint32_t sbid) {
+ block2structured_succs_[&blk].push_back(id2block_[sbid]);
+ });
+ }
+}
+
+InlinePass::GetBlocksFunction InlinePass::StructuredSuccessorsFunction() {
+ return [this](const ir::BasicBlock* block) {
+ return &(block2structured_succs_[block]);
+ };
+}
+
+bool InlinePass::HasNoReturnInLoop(ir::Function* func) {
+ // If control not structured, do not do loop/return analysis
+ // TODO: Analyze returns in non-structured control flow
+ if (!module_->HasCapability(SpvCapabilityShader))
+ return false;
+ // Compute structured block order. This order has the property
+ // that dominators are before all blocks they dominate and merge blocks
+ // are after all blocks that are in the control constructs of their header.
+ ComputeStructuredSuccessors(func);
+ auto ignore_block = [](cbb_ptr) {};
+ auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
+ std::list<const ir::BasicBlock*> structuredOrder;
+ spvtools::CFA<ir::BasicBlock>::DepthFirstTraversal(
+ &*func->begin(), StructuredSuccessorsFunction(), ignore_block,
+ [&](cbb_ptr b) { structuredOrder.push_front(b); }, ignore_edge);
+ // Search for returns in loops. Only need to track outermost loop
+ bool return_in_loop = false;
+ uint32_t outerLoopMergeId = 0;
+ for (auto& blk : structuredOrder) {
+ // Exiting current outer loop
+ if (blk->id() == outerLoopMergeId)
+ outerLoopMergeId = 0;
+ // Return block
+ auto terminal_ii = blk->cend();
+ --terminal_ii;
+ if (terminal_ii->opcode() == SpvOpReturn ||
+ terminal_ii->opcode() == SpvOpReturnValue) {
+ if (outerLoopMergeId != 0) {
+ return_in_loop = true;
+ break;
+ }
+ }
+ else if (terminal_ii != blk->cbegin()) {
+ auto merge_ii = terminal_ii;
+ --merge_ii;
+ // Entering outermost loop
+ if (merge_ii->opcode() == SpvOpLoopMerge && outerLoopMergeId == 0)
+ outerLoopMergeId = merge_ii->GetSingleWordOperand(
+ kSpvLoopMergeMergeBlockId);
+ }
+ }
+ return !return_in_loop;
+}
+
+void InlinePass::AnalyzeReturns(ir::Function* func) {
+ // Look for multiple returns
+ if (!HasMultipleReturns(func)) {
+ no_return_in_loop_.insert(func->result_id());
+ return;
+ }
+ early_return_.insert(func->result_id());
+ // If multiple returns, see if any are in a loop
+ if (HasNoReturnInLoop(func))
+ no_return_in_loop_.insert(func->result_id());
+}
+
+bool InlinePass::IsInlinableFunction(ir::Function* func) {
+ // We can only inline a function if it has blocks.
+ if (func->cbegin() == func->cend())
+ return false;
+ // Do not inline functions with returns in loops. Currently early return
+ // functions are inlined by wrapping them in a one trip loop and implementing
+ // the returns as a branch to the loop's merge block. However, this can only
+ // done validly if the return was not in a loop in the original function.
+ // Also remember functions with multiple (early) returns.
+ AnalyzeReturns(func);
+ const auto ci = no_return_in_loop_.find(func->result_id());
+ return ci != no_return_in_loop_.cend();
+}
+
void InlinePass::Initialize(ir::Module* module) {
def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module));
@@ -403,14 +602,21 @@ void InlinePass::Initialize(ir::Module* module) {
// Save module.
module_ = module;
- // Initialize function and block maps.
+ false_id_ = 0;
+
id2function_.clear();
id2block_.clear();
+ block2structured_succs_.clear();
+ inlinable_.clear();
for (auto& fn : *module_) {
+ // Initialize function and block maps.
id2function_[fn.result_id()] = &fn;
for (auto& blk : fn) {
- id2block_[blk.label_id()] = &blk;
+ id2block_[blk.id()] = &blk;
}
+ // Compute inlinability
+ if (IsInlinableFunction(&fn))
+ inlinable_.insert(fn.result_id());
}
};
diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h
index 523e1936..00d0e73f 100644
--- a/source/opt/inline_pass.h
+++ b/source/opt/inline_pass.h
@@ -21,6 +21,7 @@
#include <memory>
#include <unordered_map>
#include <vector>
+#include <list>
#include "def_use_manager.h"
#include "module.h"
@@ -31,7 +32,13 @@ namespace opt {
// See optimizer.hpp for documentation.
class InlinePass : public Pass {
+
+ using cbb_ptr = const ir::BasicBlock*;
+
public:
+ using GetBlocksFunction =
+ std::function<std::vector<ir::BasicBlock*>*(const ir::BasicBlock*)>;
+
InlinePass();
Status Process(ir::Module*) override;
@@ -56,6 +63,14 @@ class InlinePass : public Pass {
// Add unconditional branch to labelId to end of block block_ptr.
void AddBranch(uint32_t labelId, std::unique_ptr<ir::BasicBlock>* block_ptr);
+ // Add conditional branch to end of block |block_ptr|.
+ void AddBranchCond(uint32_t cond_id, uint32_t true_id,
+ uint32_t false_id, std::unique_ptr<ir::BasicBlock>* block_ptr);
+
+ // Add unconditional branch to labelId to end of block block_ptr.
+ void AddLoopMerge(uint32_t merge_id, uint32_t continue_id,
+ std::unique_ptr<ir::BasicBlock>* block_ptr);
+
// Add store of valId to ptrId to end of block block_ptr.
void AddStore(uint32_t ptrId, uint32_t valId,
std::unique_ptr<ir::BasicBlock>* block_ptr);
@@ -67,6 +82,10 @@ class InlinePass : public Pass {
// Return new label.
std::unique_ptr<ir::Instruction> NewLabel(uint32_t label_id);
+ // Returns the id for the boolean false value. Looks in the module first
+ // and creates it if not found. Remembers it for future calls.
+ uint32_t GetFalseId();
+
// Map callee params to caller args
void MapParams(ir::Function* calleeFn,
ir::UptrVectorIterator<ir::Instruction> call_inst_itr,
@@ -107,7 +126,7 @@ class InlinePass : public Pass {
// is returned. Formal parameters are trivially mapped to their actual
// parameters. Note that the first block in new_blocks retains the label
// of the original calling block. Also note that if an exit block is
- // created, it is the last block of new_blocks.
+ // created, it is the last block of new_blocks.
//
// Also return in new_vars additional OpVariable instructions required by
// and to be inserted into the caller function after the block at
@@ -117,6 +136,40 @@ class InlinePass : public Pass {
ir::UptrVectorIterator<ir::Instruction> call_inst_itr,
ir::UptrVectorIterator<ir::BasicBlock> call_block_itr);
+ // Return true if |inst| is a function call that can be inlined.
+ bool IsInlinableFunctionCall(const ir::Instruction* inst);
+
+ // Returns the id of the merge block declared by a merge instruction in
+ // this block, if any. If none, returns zero.
+ uint32_t MergeBlockIdIfAny(const ir::BasicBlock& blk);
+
+ // Compute structured successors for function |func|.
+ // A block's structured successors are the blocks it branches to
+ // together with its declared merge block if it has one.
+ // When order matters, the merge block always appears first.
+ // This assures correct depth first search in the presence of early
+ // returns and kills. If the successor vector contain duplicates
+ // if the merge block, they are safely ignored by DFS.
+ void ComputeStructuredSuccessors(ir::Function* func);
+
+ // Return function to return ordered structure successors for a given block
+ // Assumes ComputeStructuredSuccessors() has been called.
+ GetBlocksFunction StructuredSuccessorsFunction();
+
+ // Return true if |func| has multiple returns
+ bool HasMultipleReturns(ir::Function* func);
+
+ // Return true if |func| has no return in a loop. The current analysis
+ // requires structured control flow, so return false if control flow not
+ // structured ie. module is not a shader.
+ bool HasNoReturnInLoop(ir::Function* func);
+
+ // Find all functions with multiple returns and no returns in loops
+ void AnalyzeReturns(ir::Function* func);
+
+ // Return true if |func| is a function that can be inlined.
+ bool IsInlinableFunction(ir::Function* func);
+
// Exhaustively inline all function calls in func as well as in
// all code that is inlined into func. Return true if func is modified.
bool Inline(ir::Function* func);
@@ -133,6 +186,23 @@ class InlinePass : public Pass {
// Map from block's label id to block.
std::unordered_map<uint32_t, ir::BasicBlock*> id2block_;
+ // Set of ids of functions with early returns
+ std::set<uint32_t> early_return_;
+
+ // Set of ids of functions with no returns in loop
+ std::set<uint32_t> no_return_in_loop_;
+
+ // Set of ids of inlinable functions
+ std::set<uint32_t> inlinable_;
+
+ // Map from block to its structured successor blocks. See
+ // ComputeStructuredSuccessors() for definition.
+ std::unordered_map<const ir::BasicBlock*, std::vector<ir::BasicBlock*>>
+ block2structured_succs_;
+
+ // result id for OpConstantFalse
+ uint32_t false_id_;
+
// Next unused ID
uint32_t next_id_;
};
diff --git a/source/opt/instruction.h b/source/opt/instruction.h
index 2c3189cd..0143f7c9 100644
--- a/source/opt/instruction.h
+++ b/source/opt/instruction.h
@@ -23,7 +23,7 @@
#include "operand.h"
#include "spirv-tools/libspirv.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
namespace spvtools {
namespace ir {
@@ -171,7 +171,7 @@ class Instruction {
inline void ForEachInId(const std::function<void(const uint32_t*)>& f) const;
// Returns true if any operands can be labels
- inline bool has_labels() const;
+ inline bool HasLabels() const;
// Pushes the binary segments for this instruction into the back of *|binary|.
void ToBinaryWithoutAttachedDebugInsts(std::vector<uint32_t>* binary) const;
@@ -257,7 +257,7 @@ inline void Instruction::ForEachInId(
if (opnd.type == SPV_OPERAND_TYPE_ID) f(&opnd.words[0]);
}
-inline bool Instruction::has_labels() const {
+inline bool Instruction::HasLabels() const {
switch (opcode_) {
case SpvOpSelectionMerge:
case SpvOpBranch:
diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp
new file mode 100644
index 00000000..187494ff
--- /dev/null
+++ b/source/opt/local_access_chain_convert_pass.cpp
@@ -0,0 +1,369 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iterator.h"
+#include "local_access_chain_convert_pass.h"
+
+static const int kSpvEntryPointFunctionId = 1;
+static const int kSpvStorePtrId = 0;
+static const int kSpvStoreValId = 1;
+static const int kSpvLoadPtrId = 0;
+static const int kSpvAccessChainPtrId = 0;
+static const int kSpvTypePointerStorageClass = 0;
+static const int kSpvTypePointerTypeId = 1;
+static const int kSpvConstantValue = 0;
+static const int kSpvTypeIntWidth = 0;
+
+namespace spvtools {
+namespace opt {
+
+bool LocalAccessChainConvertPass::IsNonPtrAccessChain(
+ const SpvOp opcode) const {
+ return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
+}
+
+bool LocalAccessChainConvertPass::IsMathType(
+ const ir::Instruction* typeInst) const {
+ switch (typeInst->opcode()) {
+ case SpvOpTypeInt:
+ case SpvOpTypeFloat:
+ case SpvOpTypeBool:
+ case SpvOpTypeVector:
+ case SpvOpTypeMatrix:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+bool LocalAccessChainConvertPass::IsTargetType(
+ const ir::Instruction* typeInst) const {
+ if (IsMathType(typeInst))
+ return true;
+ if (typeInst->opcode() == SpvOpTypeArray)
+ return IsMathType(def_use_mgr_->GetDef(typeInst->GetSingleWordOperand(1)));
+ if (typeInst->opcode() != SpvOpTypeStruct)
+ return false;
+ // All struct members must be math type
+ int nonMathComp = 0;
+ typeInst->ForEachInId([&nonMathComp,this](const uint32_t* tid) {
+ ir::Instruction* compTypeInst = def_use_mgr_->GetDef(*tid);
+ if (!IsMathType(compTypeInst)) ++nonMathComp;
+ });
+ return nonMathComp == 0;
+}
+
+ir::Instruction* LocalAccessChainConvertPass::GetPtr(
+ ir::Instruction* ip,
+ uint32_t* varId) {
+ const uint32_t ptrId = ip->GetSingleWordInOperand(
+ ip->opcode() == SpvOpStore ? kSpvStorePtrId : kSpvLoadPtrId);
+ ir::Instruction* ptrInst = def_use_mgr_->GetDef(ptrId);
+ *varId = IsNonPtrAccessChain(ptrInst->opcode()) ?
+ ptrInst->GetSingleWordInOperand(kSpvAccessChainPtrId) :
+ ptrId;
+ return ptrInst;
+}
+
+bool LocalAccessChainConvertPass::IsTargetVar(uint32_t varId) {
+ if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end())
+ return false;
+ if (seen_target_vars_.find(varId) != seen_target_vars_.end())
+ return true;
+ const ir::Instruction* varInst = def_use_mgr_->GetDef(varId);
+ if (varInst->opcode() != SpvOpVariable)
+ return false;;
+ const uint32_t varTypeId = varInst->type_id();
+ const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
+ if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
+ SpvStorageClassFunction) {
+ seen_non_target_vars_.insert(varId);
+ return false;
+ }
+ const uint32_t varPteTypeId =
+ varTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
+ ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
+ if (!IsTargetType(varPteTypeInst)) {
+ seen_non_target_vars_.insert(varId);
+ return false;
+ }
+ seen_target_vars_.insert(varId);
+ return true;
+}
+
+void LocalAccessChainConvertPass::DeleteIfUseless(ir::Instruction* inst) {
+ const uint32_t resId = inst->result_id();
+ assert(resId != 0);
+ analysis::UseList* uses = def_use_mgr_->GetUses(resId);
+ if (uses == nullptr)
+ def_use_mgr_->KillInst(inst);
+}
+
+void LocalAccessChainConvertPass::ReplaceAndDeleteLoad(
+ ir::Instruction* loadInst,
+ uint32_t replId,
+ ir::Instruction* ptrInst) {
+ const uint32_t loadId = loadInst->result_id();
+ (void) def_use_mgr_->ReplaceAllUsesWith(loadId, replId);
+ // remove load instruction
+ def_use_mgr_->KillInst(loadInst);
+ // if access chain, see if it can be removed as well
+ if (IsNonPtrAccessChain(ptrInst->opcode())) {
+ DeleteIfUseless(ptrInst);
+ }
+}
+
+uint32_t LocalAccessChainConvertPass::GetPointeeTypeId(
+ const ir::Instruction* ptrInst) const {
+ const uint32_t ptrTypeId = ptrInst->type_id();
+ const ir::Instruction* ptrTypeInst = def_use_mgr_->GetDef(ptrTypeId);
+ return ptrTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
+}
+
+void LocalAccessChainConvertPass::BuildAndAppendInst(
+ SpvOp opcode,
+ uint32_t typeId,
+ uint32_t resultId,
+ const std::vector<ir::Operand>& in_opnds,
+ std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
+ std::unique_ptr<ir::Instruction> newInst(new ir::Instruction(
+ opcode, typeId, resultId, in_opnds));
+ def_use_mgr_->AnalyzeInstDefUse(&*newInst);
+ newInsts->emplace_back(std::move(newInst));
+}
+
+uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
+ const ir::Instruction* ptrInst,
+ uint32_t* varId,
+ uint32_t* varPteTypeId,
+ std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
+ const uint32_t ldResultId = TakeNextId();
+ *varId = ptrInst->GetSingleWordInOperand(kSpvAccessChainPtrId);
+ const ir::Instruction* varInst = def_use_mgr_->GetDef(*varId);
+ assert(varInst->opcode() == SpvOpVariable);
+ *varPteTypeId = GetPointeeTypeId(varInst);
+ BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}}, newInsts);
+ return ldResultId;
+}
+
+void LocalAccessChainConvertPass::AppendConstantOperands(
+ const ir::Instruction* ptrInst,
+ std::vector<ir::Operand>* in_opnds) {
+ uint32_t iidIdx = 0;
+ ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t *iid) {
+ if (iidIdx > 0) {
+ const ir::Instruction* cInst = def_use_mgr_->GetDef(*iid);
+ uint32_t val = cInst->GetSingleWordInOperand(kSpvConstantValue);
+ in_opnds->push_back(
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
+ }
+ ++iidIdx;
+ });
+}
+
+uint32_t LocalAccessChainConvertPass::GenAccessChainLoadReplacement(
+ const ir::Instruction* ptrInst,
+ std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
+
+ // Build and append load of variable in ptrInst
+ uint32_t varId;
+ uint32_t varPteTypeId;
+ const uint32_t ldResultId = BuildAndAppendVarLoad(ptrInst, &varId,
+ &varPteTypeId, newInsts);
+
+ // Build and append Extract
+ const uint32_t extResultId = TakeNextId();
+ const uint32_t ptrPteTypeId = GetPointeeTypeId(ptrInst);
+ std::vector<ir::Operand> ext_in_opnds =
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
+ AppendConstantOperands(ptrInst, &ext_in_opnds);
+ BuildAndAppendInst(SpvOpCompositeExtract, ptrPteTypeId, extResultId,
+ ext_in_opnds, newInsts);
+ return extResultId;
+}
+
+void LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
+ const ir::Instruction* ptrInst,
+ uint32_t valId,
+ std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
+
+ // Build and append load of variable in ptrInst
+ uint32_t varId;
+ uint32_t varPteTypeId;
+ const uint32_t ldResultId = BuildAndAppendVarLoad(ptrInst, &varId,
+ &varPteTypeId, newInsts);
+
+ // Build and append Insert
+ const uint32_t insResultId = TakeNextId();
+ std::vector<ir::Operand> ins_in_opnds =
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
+ AppendConstantOperands(ptrInst, &ins_in_opnds);
+ BuildAndAppendInst(
+ SpvOpCompositeInsert, varPteTypeId, insResultId, ins_in_opnds, newInsts);
+
+ // Build and append Store
+ BuildAndAppendInst(SpvOpStore, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
+ newInsts);
+}
+
+bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
+ const ir::Instruction* acp) const {
+ uint32_t inIdx = 0;
+ uint32_t nonConstCnt = 0;
+ acp->ForEachInId([&inIdx, &nonConstCnt, this](const uint32_t* tid) {
+ if (inIdx > 0) {
+ ir::Instruction* opInst = def_use_mgr_->GetDef(*tid);
+ if (opInst->opcode() != SpvOpConstant) ++nonConstCnt;
+ }
+ ++inIdx;
+ });
+ return nonConstCnt == 0;
+}
+
+void LocalAccessChainConvertPass::FindTargetVars(ir::Function* func) {
+ for (auto bi = func->begin(); bi != func->end(); ++bi) {
+ for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
+ switch (ii->opcode()) {
+ case SpvOpStore:
+ case SpvOpLoad: {
+ uint32_t varId;
+ ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+ // For now, only convert non-ptr access chains
+ if (!IsNonPtrAccessChain(ptrInst->opcode()))
+ break;
+ // For now, only convert non-nested access chains
+ // TODO(): Convert nested access chains
+ if (!IsTargetVar(varId))
+ break;
+ // Rule out variables accessed with non-constant indices
+ if (!IsConstantIndexAccessChain(ptrInst)) {
+ seen_non_target_vars_.insert(varId);
+ seen_target_vars_.erase(varId);
+ break;
+ }
+ } break;
+ default:
+ break;
+ }
+ }
+ }
+}
+
+bool LocalAccessChainConvertPass::ConvertLocalAccessChains(ir::Function* func) {
+ FindTargetVars(func);
+ // Replace access chains of all targeted variables with equivalent
+ // extract and insert sequences
+ bool modified = false;
+ for (auto bi = func->begin(); bi != func->end(); ++bi) {
+ for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
+ switch (ii->opcode()) {
+ case SpvOpLoad: {
+ uint32_t varId;
+ ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+ if (!IsNonPtrAccessChain(ptrInst->opcode()))
+ break;
+ if (!IsTargetVar(varId))
+ break;
+ std::vector<std::unique_ptr<ir::Instruction>> newInsts;
+ uint32_t replId =
+ GenAccessChainLoadReplacement(ptrInst, &newInsts);
+ ReplaceAndDeleteLoad(&*ii, replId, ptrInst);
+ ++ii;
+ ii = ii.InsertBefore(&newInsts);
+ ++ii;
+ modified = true;
+ } break;
+ case SpvOpStore: {
+ uint32_t varId;
+ ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+ if (!IsNonPtrAccessChain(ptrInst->opcode()))
+ break;
+ if (!IsTargetVar(varId))
+ break;
+ std::vector<std::unique_ptr<ir::Instruction>> newInsts;
+ uint32_t valId = ii->GetSingleWordInOperand(kSpvStoreValId);
+ GenAccessChainStoreReplacement(ptrInst, valId, &newInsts);
+ def_use_mgr_->KillInst(&*ii);
+ DeleteIfUseless(ptrInst);
+ ++ii;
+ ii = ii.InsertBefore(&newInsts);
+ ++ii;
+ ++ii;
+ modified = true;
+ } break;
+ default:
+ break;
+ }
+ }
+ }
+ return modified;
+}
+
+void LocalAccessChainConvertPass::Initialize(ir::Module* module) {
+
+ module_ = module;
+
+ // Initialize function and block maps
+ id2function_.clear();
+ for (auto& fn : *module_)
+ id2function_[fn.result_id()] = &fn;
+
+ // Initialize Target Variable Caches
+ seen_target_vars_.clear();
+ seen_non_target_vars_.clear();
+
+ def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+ // Initialize next unused Id.
+ next_id_ = module->id_bound();
+};
+
+Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
+ // If non-32-bit integer type in module, terminate processing
+ // TODO(): Handle non-32-bit integer constants in access chains
+ for (const ir::Instruction& inst : module_->types_values())
+ if (inst.opcode() == SpvOpTypeInt &&
+ inst.GetSingleWordInOperand(kSpvTypeIntWidth) != 32)
+ return Status::SuccessWithoutChange;
+ // Process all entry point functions.
+ bool modified = false;
+ for (auto& e : module_->entry_points()) {
+ ir::Function* fn =
+ id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
+ modified = modified || ConvertLocalAccessChains(fn);
+ }
+
+ FinalizeNextId(module_);
+
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+LocalAccessChainConvertPass::LocalAccessChainConvertPass()
+ : module_(nullptr), def_use_mgr_(nullptr), next_id_(0) {}
+
+Pass::Status LocalAccessChainConvertPass::Process(ir::Module* module) {
+ Initialize(module);
+ return ProcessImpl();
+}
+
+} // namespace opt
+} // namespace spvtools
+
diff --git a/source/opt/local_access_chain_convert_pass.h b/source/opt/local_access_chain_convert_pass.h
new file mode 100644
index 00000000..3a2d6054
--- /dev/null
+++ b/source/opt/local_access_chain_convert_pass.h
@@ -0,0 +1,167 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_
+#define LIBSPIRV_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_
+
+
+#include <algorithm>
+#include <map>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "basic_block.h"
+#include "def_use_manager.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class LocalAccessChainConvertPass : public Pass {
+ public:
+ LocalAccessChainConvertPass();
+ const char* name() const override { return "convert-local-access-chains"; }
+ Status Process(ir::Module*) override;
+
+ private:
+ // Returns true if |opcode| is a non-pointer access chain op
+ // TODO(): Support conversion of pointer access chains.
+ bool IsNonPtrAccessChain(const SpvOp opcode) const;
+
+ // Returns true if |typeInst| is a scalar type
+ // or a vector or matrix
+ bool IsMathType(const ir::Instruction* typeInst) const;
+
+ // Returns true if |typeInst| is a math type or a struct or array
+ // of a math type.
+ // TODO(): Add more complex types to convert
+ bool IsTargetType(const ir::Instruction* typeInst) const;
+
+ // Given a load or store |ip|, return the pointer instruction.
+ // If the pointer is an access chain, |*varId| is its base id.
+ // Otherwise it is the id of the pointer of the load/store.
+ ir::Instruction* GetPtr(ir::Instruction* ip, uint32_t* varId);
+
+ // Search |func| and cache function scope variables of target type that are
+ // not accessed with non-constant-index access chains. Also cache non-target
+ // variables.
+ void FindTargetVars(ir::Function* func);
+
+ // Return true if |varId| is a previously identified target variable.
+ // Return false if |varId| is a previously identified non-target variable.
+ // See FindTargetVars() for definition of target variable. If variable is
+ // not cached, return true if variable is a function scope variable of
+ // target type, false otherwise. Updates caches of target and non-target
+ // variables.
+ bool IsTargetVar(uint32_t varId);
+
+ // Delete |inst| if it has no uses. Assumes |inst| has a non-zero resultId.
+ void DeleteIfUseless(ir::Instruction* inst);
+
+ // Replace all instances of |loadInst|'s id with |replId| and delete
+ // |loadInst| and its pointer |ptrInst| if it is a useless access chain.
+ void ReplaceAndDeleteLoad(ir::Instruction* loadInst,
+ uint32_t replId,
+ ir::Instruction* ptrInst);
+
+ // Return type id for |ptrInst|'s pointee
+ uint32_t GetPointeeTypeId(const ir::Instruction* ptrInst) const;
+
+ // Build instruction from |opcode|, |typeId|, |resultId|, and |in_opnds|.
+ // Append to |newInsts|.
+ void BuildAndAppendInst(SpvOp opcode, uint32_t typeId, uint32_t resultId,
+ const std::vector<ir::Operand>& in_opnds,
+ std::vector<std::unique_ptr<ir::Instruction>>* newInsts);
+
+ // Build load of variable in |ptrInst| and append to |newInsts|.
+ // Return var in |varId| and its pointee type in |varPteTypeId|.
+ uint32_t BuildAndAppendVarLoad(const ir::Instruction* ptrInst,
+ uint32_t* varId, uint32_t* varPteTypeId,
+ std::vector<std::unique_ptr<ir::Instruction>>* newInsts);
+
+ // Append literal integer operands to |in_opnds| corresponding to constant
+ // integer operands from access chain |ptrInst|. Assumes all indices in
+ // access chains are OpConstant.
+ void AppendConstantOperands( const ir::Instruction* ptrInst,
+ std::vector<ir::Operand>* in_opnds);
+
+ // Create a load/insert/store equivalent to a store of
+ // |valId| through (constant index) access chaing |ptrInst|.
+ // Append to |newInsts|.
+ void GenAccessChainStoreReplacement(const ir::Instruction* ptrInst,
+ uint32_t valId,
+ std::vector<std::unique_ptr<ir::Instruction>>* newInsts);
+
+ // For the (constant index) access chain |ptrInst|, create an
+ // equivalent load and extract. Append to |newInsts|.
+ uint32_t GenAccessChainLoadReplacement(const ir::Instruction* ptrInst,
+ std::vector<std::unique_ptr<ir::Instruction>>* newInsts);
+
+ // Return true if all indices of access chain |acp| are OpConstant integers
+ bool IsConstantIndexAccessChain(const ir::Instruction* acp) const;
+
+ // Identify all function scope variables of target type which are
+ // accessed only with loads, stores and access chains with constant
+ // indices. Convert all loads and stores of such variables into equivalent
+ // loads, stores, extracts and inserts. This unifies access to these
+ // variables to a single mode and simplifies analysis and optimization.
+ // See IsTargetType() for targeted types.
+ //
+ // Nested access chains and pointer access chains are not currently
+ // converted.
+ bool ConvertLocalAccessChains(ir::Function* func);
+
+ // Save next available id into |module|.
+ inline void FinalizeNextId(ir::Module* module) {
+ module->SetIdBound(next_id_);
+ }
+
+ // Return next available id and calculate next.
+ inline uint32_t TakeNextId() {
+ return next_id_++;
+ }
+
+ void Initialize(ir::Module* module);
+ Pass::Status ProcessImpl();
+
+ // Module this pass is processing
+ ir::Module* module_;
+
+ // Def-Uses for the module we are processing
+ std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
+
+ // Map from function's result id to function
+ std::unordered_map<uint32_t, ir::Function*> id2function_;
+
+ // Cache of verified target vars
+ std::unordered_set<uint32_t> seen_target_vars_;
+
+ // Cache of verified non-target vars
+ std::unordered_set<uint32_t> seen_non_target_vars_;
+
+ // Next unused ID
+ uint32_t next_id_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_
+
diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp
new file mode 100644
index 00000000..b18b08d0
--- /dev/null
+++ b/source/opt/local_single_block_elim_pass.cpp
@@ -0,0 +1,344 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iterator.h"
+#include "local_single_block_elim_pass.h"
+
+static const int kSpvEntryPointFunctionId = 1;
+static const int kSpvStorePtrId = 0;
+static const int kSpvStoreValId = 1;
+static const int kSpvLoadPtrId = 0;
+static const int kSpvAccessChainPtrId = 0;
+static const int kSpvTypePointerStorageClass = 0;
+static const int kSpvTypePointerTypeId = 1;
+
+namespace spvtools {
+namespace opt {
+
+bool LocalSingleBlockLoadStoreElimPass::IsNonPtrAccessChain(
+ const SpvOp opcode) const {
+ return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
+}
+
+bool LocalSingleBlockLoadStoreElimPass::IsMathType(
+ const ir::Instruction* typeInst) const {
+ switch (typeInst->opcode()) {
+ case SpvOpTypeInt:
+ case SpvOpTypeFloat:
+ case SpvOpTypeBool:
+ case SpvOpTypeVector:
+ case SpvOpTypeMatrix:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+bool LocalSingleBlockLoadStoreElimPass::IsTargetType(
+ const ir::Instruction* typeInst) const {
+ if (IsMathType(typeInst))
+ return true;
+ if (typeInst->opcode() == SpvOpTypeArray)
+ return IsMathType(def_use_mgr_->GetDef(typeInst->GetSingleWordOperand(1)));
+ if (typeInst->opcode() != SpvOpTypeStruct)
+ return false;
+ // All struct members must be math type
+ int nonMathComp = 0;
+ typeInst->ForEachInId([&nonMathComp,this](const uint32_t* tid) {
+ ir::Instruction* compTypeInst = def_use_mgr_->GetDef(*tid);
+ if (!IsMathType(compTypeInst)) ++nonMathComp;
+ });
+ return nonMathComp == 0;
+}
+
+ir::Instruction* LocalSingleBlockLoadStoreElimPass::GetPtr(
+ ir::Instruction* ip, uint32_t* varId) {
+ *varId = ip->GetSingleWordInOperand(
+ ip->opcode() == SpvOpStore ? kSpvStorePtrId : kSpvLoadPtrId);
+ ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
+ ir::Instruction* varInst = ptrInst;
+ while (IsNonPtrAccessChain(varInst->opcode())) {
+ *varId = varInst->GetSingleWordInOperand(kSpvAccessChainPtrId);
+ varInst = def_use_mgr_->GetDef(*varId);
+ }
+ return ptrInst;
+}
+
+bool LocalSingleBlockLoadStoreElimPass::IsTargetVar(uint32_t varId) {
+ if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end())
+ return false;
+ if (seen_target_vars_.find(varId) != seen_target_vars_.end())
+ return true;
+ const ir::Instruction* varInst = def_use_mgr_->GetDef(varId);
+ assert(varInst->opcode() == SpvOpVariable);
+ const uint32_t varTypeId = varInst->type_id();
+ const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
+ if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
+ SpvStorageClassFunction) {
+ seen_non_target_vars_.insert(varId);
+ return false;
+ }
+ const uint32_t varPteTypeId =
+ varTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
+ ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
+ if (!IsTargetType(varPteTypeInst)) {
+ seen_non_target_vars_.insert(varId);
+ return false;
+ }
+ seen_target_vars_.insert(varId);
+ return true;
+}
+
+void LocalSingleBlockLoadStoreElimPass::ReplaceAndDeleteLoad(
+ ir::Instruction* loadInst, uint32_t replId) {
+ const uint32_t loadId = loadInst->result_id();
+ (void) def_use_mgr_->ReplaceAllUsesWith(loadId, replId);
+ // TODO(greg-lunarg): Consider moving DCE into separate pass
+ DCEInst(loadInst);
+}
+
+bool LocalSingleBlockLoadStoreElimPass::HasLoads(uint32_t ptrId) const {
+ analysis::UseList* uses = def_use_mgr_->GetUses(ptrId);
+ if (uses == nullptr)
+ return false;
+ for (auto u : *uses) {
+ SpvOp op = u.inst->opcode();
+ if (IsNonPtrAccessChain(op)) {
+ if (HasLoads(u.inst->result_id()))
+ return true;
+ }
+ else {
+ // Conservatively assume that calls will do a load
+ // TODO(): Improve analysis around function calls
+ if (op == SpvOpLoad || op == SpvOpFunctionCall)
+ return true;
+ }
+ }
+ return false;
+}
+
+bool LocalSingleBlockLoadStoreElimPass::IsLiveVar(uint32_t varId) const {
+ // non-function scope vars are live
+ const ir::Instruction* varInst = def_use_mgr_->GetDef(varId);
+ assert(varInst->opcode() == SpvOpVariable);
+ const uint32_t varTypeId = varInst->type_id();
+ const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
+ if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
+ SpvStorageClassFunction)
+ return true;
+ // test if variable is loaded from
+ return HasLoads(varId);
+}
+
+bool LocalSingleBlockLoadStoreElimPass::IsLiveStore(
+ ir::Instruction* storeInst) {
+ // get store's variable
+ uint32_t varId;
+ (void) GetPtr(storeInst, &varId);
+ return IsLiveVar(varId);
+}
+
+void LocalSingleBlockLoadStoreElimPass::AddStores(
+ uint32_t ptr_id, std::queue<ir::Instruction*>* insts) {
+ analysis::UseList* uses = def_use_mgr_->GetUses(ptr_id);
+ if (uses != nullptr) {
+ for (auto u : *uses) {
+ if (IsNonPtrAccessChain(u.inst->opcode()))
+ AddStores(u.inst->result_id(), insts);
+ else if (u.inst->opcode() == SpvOpStore)
+ insts->push(u.inst);
+ }
+ }
+}
+
+void LocalSingleBlockLoadStoreElimPass::DCEInst(ir::Instruction* inst) {
+ std::queue<ir::Instruction*> deadInsts;
+ deadInsts.push(inst);
+ while (!deadInsts.empty()) {
+ ir::Instruction* di = deadInsts.front();
+ // Don't delete labels
+ if (di->opcode() == SpvOpLabel) {
+ deadInsts.pop();
+ continue;
+ }
+ // Remember operands
+ std::vector<uint32_t> ids;
+ di->ForEachInId([&ids](uint32_t* iid) {
+ ids.push_back(*iid);
+ });
+ uint32_t varId = 0;
+ // Remember variable if dead load
+ if (di->opcode() == SpvOpLoad)
+ (void) GetPtr(di, &varId);
+ def_use_mgr_->KillInst(di);
+ // For all operands with no remaining uses, add their instruction
+ // to the dead instruction queue.
+ for (auto id : ids) {
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses == nullptr)
+ deadInsts.push(def_use_mgr_->GetDef(id));
+ }
+ // if a load was deleted and it was the variable's
+ // last load, add all its stores to dead queue
+ if (varId != 0 && !IsLiveVar(varId))
+ AddStores(varId, &deadInsts);
+ deadInsts.pop();
+ }
+}
+
+bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim(
+ ir::Function* func) {
+ // Verify no CopyObject ops in function. This is a pre-SSA pass and
+ // is generally not useful for code already in CSSA form.
+ for (auto& blk : *func)
+ for (auto& inst : blk)
+ if (inst.opcode() == SpvOpCopyObject)
+ return false;
+ // Perform local store/load and load/load elimination on each block
+ bool modified = false;
+ for (auto bi = func->begin(); bi != func->end(); ++bi) {
+ var2store_.clear();
+ var2load_.clear();
+ pinned_vars_.clear();
+ for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
+ switch (ii->opcode()) {
+ case SpvOpStore: {
+ // Verify store variable is target type
+ uint32_t varId;
+ ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+ if (!IsTargetVar(varId))
+ continue;
+ // Register the store
+ if (ptrInst->opcode() == SpvOpVariable) {
+ // if not pinned, look for WAW
+ if (pinned_vars_.find(varId) == pinned_vars_.end()) {
+ auto si = var2store_.find(varId);
+ if (si != var2store_.end()) {
+ def_use_mgr_->KillInst(si->second);
+ }
+ }
+ var2store_[varId] = &*ii;
+ }
+ else {
+ assert(IsNonPtrAccessChain(ptrInst->opcode()));
+ var2store_.erase(varId);
+ }
+ pinned_vars_.erase(varId);
+ var2load_.erase(varId);
+ } break;
+ case SpvOpLoad: {
+ // Verify store variable is target type
+ uint32_t varId;
+ ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+ if (!IsTargetVar(varId))
+ continue;
+ // Look for previous store or load
+ uint32_t replId = 0;
+ if (ptrInst->opcode() == SpvOpVariable) {
+ auto si = var2store_.find(varId);
+ if (si != var2store_.end()) {
+ replId = si->second->GetSingleWordInOperand(kSpvStoreValId);
+ }
+ else {
+ auto li = var2load_.find(varId);
+ if (li != var2load_.end()) {
+ replId = li->second->result_id();
+ }
+ }
+ }
+ if (replId != 0) {
+ // replace load's result id and delete load
+ ReplaceAndDeleteLoad(&*ii, replId);
+ modified = true;
+ }
+ else {
+ if (ptrInst->opcode() == SpvOpVariable)
+ var2load_[varId] = &*ii; // register load
+ pinned_vars_.insert(varId);
+ }
+ } break;
+ case SpvOpFunctionCall: {
+ // Conservatively assume all locals are redefined for now.
+ // TODO(): Handle more optimally
+ var2store_.clear();
+ var2load_.clear();
+ pinned_vars_.clear();
+ } break;
+ default:
+ break;
+ }
+ }
+ // Go back and delete useless stores in block
+ // TODO(greg-lunarg): Consider moving DCE into separate pass
+ for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
+ if (ii->opcode() != SpvOpStore)
+ continue;
+ if (IsLiveStore(&*ii))
+ continue;
+ DCEInst(&*ii);
+ }
+ }
+ return modified;
+}
+
+void LocalSingleBlockLoadStoreElimPass::Initialize(ir::Module* module) {
+
+ module_ = module;
+
+ // Initialize function and block maps
+ id2function_.clear();
+ for (auto& fn : *module_)
+ id2function_[fn.result_id()] = &fn;
+
+ // Initialize Target Type Caches
+ seen_target_vars_.clear();
+ seen_non_target_vars_.clear();
+
+ // TODO(): Reuse def/use from previous passes
+ def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+ // Start new ids with next availablein module
+ next_id_ = module_->id_bound();
+
+};
+
+Pass::Status LocalSingleBlockLoadStoreElimPass::ProcessImpl() {
+ // Assumes logical addressing only
+ if (module_->HasCapability(SpvCapabilityAddresses))
+ return Status::SuccessWithoutChange;
+ bool modified = false;
+ // Call Mem2Reg on all remaining functions.
+ for (auto& e : module_->entry_points()) {
+ ir::Function* fn =
+ id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
+ modified = modified || LocalSingleBlockLoadStoreElim(fn);
+ }
+ FinalizeNextId(module_);
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElimPass()
+ : module_(nullptr), def_use_mgr_(nullptr), next_id_(0) {}
+
+Pass::Status LocalSingleBlockLoadStoreElimPass::Process(ir::Module* module) {
+ Initialize(module);
+ return ProcessImpl();
+}
+
+} // namespace opt
+} // namespace spvtools
+
diff --git a/source/opt/local_single_block_elim_pass.h b/source/opt/local_single_block_elim_pass.h
new file mode 100644
index 00000000..b5a14f42
--- /dev/null
+++ b/source/opt/local_single_block_elim_pass.h
@@ -0,0 +1,153 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_
+#define LIBSPIRV_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_
+
+
+#include <algorithm>
+#include <map>
+#include <queue>
+#include <utility>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "basic_block.h"
+#include "def_use_manager.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class LocalSingleBlockLoadStoreElimPass : public Pass {
+ public:
+ LocalSingleBlockLoadStoreElimPass();
+ const char* name() const override { return "eliminate-local-single-block"; }
+ Status Process(ir::Module*) override;
+
+ private:
+ // Returns true if |opcode| is a non-ptr access chain op
+ bool IsNonPtrAccessChain(const SpvOp opcode) const;
+
+ // Returns true if |typeInst| is a scalar type
+ // or a vector or matrix
+ bool IsMathType(const ir::Instruction* typeInst) const;
+
+ // Returns true if |typeInst| is a math type or a struct or array
+ // of a math type.
+ bool IsTargetType(const ir::Instruction* typeInst) const;
+
+ // Given a load or store |ip|, return the pointer instruction.
+ // Also return the base variable's id in |varId|.
+ ir::Instruction* GetPtr(ir::Instruction* ip, uint32_t* varId);
+
+ // Return true if |varId| is a previously identified target variable.
+ // Return false if |varId| is a previously identified non-target variable.
+ // If variable is not cached, return true if variable is a function scope
+ // variable of target type, false otherwise. Updates caches of target
+ // and non-target variables.
+ bool IsTargetVar(uint32_t varId);
+
+ // Replace all instances of |loadInst|'s id with |replId| and delete
+ // |loadInst|.
+ void ReplaceAndDeleteLoad(ir::Instruction* loadInst, uint32_t replId);
+
+ // Return true if any instruction loads from |ptrId|
+ bool HasLoads(uint32_t ptrId) const;
+
+ // Return true if |varId| is not a function variable or if it has
+ // a load
+ bool IsLiveVar(uint32_t varId) const;
+
+ // Return true if |storeInst| is not to function variable or if its
+ // base variable has a load
+ bool IsLiveStore(ir::Instruction* storeInst);
+
+ // Add stores using |ptr_id| to |insts|
+ void AddStores(uint32_t ptr_id, std::queue<ir::Instruction*>* insts);
+
+ // Delete |inst| and iterate DCE on all its operands. Won't delete
+ // labels.
+ void DCEInst(ir::Instruction* inst);
+
+ // On all entry point functions, within each basic block, eliminate
+ // loads and stores to function variables where possible. For
+ // loads, if previous load or store to same variable, replace
+ // load id with previous id and delete load. Finally, check if
+ // remaining stores are useless, and delete store and variable
+ // where possible. Assumes logical addressing.
+ bool LocalSingleBlockLoadStoreElim(ir::Function* func);
+
+ // Save next available id into |module|.
+ inline void FinalizeNextId(ir::Module* module) {
+ module->SetIdBound(next_id_);
+ }
+
+ // Return next available id and calculate next.
+ inline uint32_t TakeNextId() {
+ return next_id_++;
+ }
+
+ void Initialize(ir::Module* module);
+ Pass::Status ProcessImpl();
+
+ // Module this pass is processing
+ ir::Module* module_;
+
+ // Def-Uses for the module we are processing
+ std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
+
+ // Map from function's result id to function
+ std::unordered_map<uint32_t, ir::Function*> id2function_;
+
+ // Cache of previously seen target types
+ std::unordered_set<uint32_t> seen_target_vars_;
+
+ // Cache of previously seen non-target types
+ std::unordered_set<uint32_t> seen_non_target_vars_;
+
+ // Map from function scope variable to a store of that variable in the
+ // current block whose value is currently valid. This map is cleared
+ // at the start of each block and incrementally updated as the block
+ // is scanned. The stores are candidates for elimination. The map is
+ // conservatively cleared when a function call is encountered.
+ std::unordered_map<uint32_t, ir::Instruction*> var2store_;
+
+ // Map from function scope variable to a load of that variable in the
+ // current block whose value is currently valid. This map is cleared
+ // at the start of each block and incrementally updated as the block
+ // is scanned. The stores are candidates for elimination. The map is
+ // conservatively cleared when a function call is encountered.
+ std::unordered_map<uint32_t, ir::Instruction*> var2load_;
+
+ // Set of variables whose most recent store in the current block cannot be
+ // deleted, for example, if there is a load of the variable which is
+ // dependent on the store and is not replaced and deleted by this pass,
+ // for example, a load through an access chain. A variable is removed
+ // from this set each time a new store of that variable is encountered.
+ std::unordered_set<uint32_t> pinned_vars_;
+
+ // Next unused ID
+ uint32_t next_id_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_
+
diff --git a/source/opt/log.h b/source/opt/log.h
index 717a3625..70ae223c 100644
--- a/source/opt/log.h
+++ b/source/opt/log.h
@@ -108,8 +108,11 @@ void Logf(const MessageConsumer& consumer, spv_message_level_t level,
return;
}
- if (size >= 0) { // The initial buffer is insufficient.
- std::vector<char> longer_message(size + 1);
+ if (size >= 0) {
+ // The initial buffer is insufficient. Allocate a buffer of a larger size,
+ // and write to it instead. Force the size to be unsigned to avoid a
+ // warning in GCC 7.1.
+ std::vector<char> longer_message(size + 1u);
snprintf(longer_message.data(), longer_message.size(), format,
std::forward<Args>(args)...);
Log(consumer, level, source, position, longer_message.data());
diff --git a/source/opt/module.cpp b/source/opt/module.cpp
index 372b70ce..dd08ca0e 100644
--- a/source/opt/module.cpp
+++ b/source/opt/module.cpp
@@ -58,6 +58,21 @@ std::vector<const Instruction*> Module::GetConstants() const {
return insts;
};
+uint32_t Module::GetGlobalValue(SpvOp opcode) const {
+ for (uint32_t i = 0; i < types_values_.size(); ++i) {
+ if (types_values_[i]->opcode() == opcode)
+ return types_values_[i]->result_id();
+ }
+ return 0;
+}
+
+void Module::AddGlobalValue(SpvOp opcode, uint32_t result_id,
+ uint32_t type_id) {
+ std::unique_ptr<ir::Instruction> newGlobal(
+ new ir::Instruction(opcode, type_id, result_id, {}));
+ AddGlobalValue(std::move(newGlobal));
+}
+
void Module::ForEachInst(const std::function<void(Instruction*)>& f,
bool run_on_debug_line_insts) {
#define DELEGATE(i) i->ForEachInst(f, run_on_debug_line_insts)
@@ -76,9 +91,9 @@ void Module::ForEachInst(const std::function<void(Instruction*)>& f,
void Module::ForEachInst(const std::function<void(const Instruction*)>& f,
bool run_on_debug_line_insts) const {
-#define DELEGATE(i) \
- static_cast<const Instruction*>(i.get()) \
- ->ForEachInst(f, run_on_debug_line_insts)
+#define DELEGATE(i) \
+ static_cast<const Instruction*>(i.get())->ForEachInst( \
+ f, run_on_debug_line_insts)
for (auto& i : capabilities_) DELEGATE(i);
for (auto& i : extensions_) DELEGATE(i);
for (auto& i : ext_inst_imports_) DELEGATE(i);
@@ -89,8 +104,8 @@ void Module::ForEachInst(const std::function<void(const Instruction*)>& f,
for (auto& i : annotations_) DELEGATE(i);
for (auto& i : types_values_) DELEGATE(i);
for (auto& i : functions_) {
- static_cast<const Function*>(i.get())
- ->ForEachInst(f, run_on_debug_line_insts);
+ static_cast<const Function*>(i.get())->ForEachInst(f,
+ run_on_debug_line_insts);
}
#undef DELEGATE
}
@@ -103,7 +118,7 @@ void Module::ToBinary(std::vector<uint32_t>* binary, bool skip_nop) const {
binary->push_back(header_.bound);
binary->push_back(header_.reserved);
- auto write_inst = [this, binary, skip_nop](const Instruction* i) {
+ auto write_inst = [binary, skip_nop](const Instruction* i) {
if (!(skip_nop && i->IsNop())) i->ToBinaryWithoutAttachedDebugInsts(binary);
};
ForEachInst(write_inst, true);
@@ -112,16 +127,28 @@ void Module::ToBinary(std::vector<uint32_t>* binary, bool skip_nop) const {
uint32_t Module::ComputeIdBound() const {
uint32_t highest = 0;
- ForEachInst([&highest](const Instruction* inst) {
- for (const auto& operand : *inst) {
- if (spvIsIdType(operand.type)) {
- highest = std::max(highest, operand.words[0]);
- }
- }
- }, true /* scan debug line insts as well */);
+ ForEachInst(
+ [&highest](const Instruction* inst) {
+ for (const auto& operand : *inst) {
+ if (spvIsIdType(operand.type)) {
+ highest = std::max(highest, operand.words[0]);
+ }
+ }
+ },
+ true /* scan debug line insts as well */);
return highest + 1;
}
+bool Module::HasCapability(uint32_t cap) {
+ for (auto& ci : capabilities_) {
+ uint32_t tcap = ci->GetSingleWordOperand(0);
+ if (tcap == cap) {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace ir
} // namespace spvtools
diff --git a/source/opt/module.h b/source/opt/module.h
index 5d98ddcd..37d49025 100644
--- a/source/opt/module.h
+++ b/source/opt/module.h
@@ -86,6 +86,12 @@ class Module {
std::vector<Instruction*> GetConstants();
std::vector<const Instruction*> GetConstants() const;
+ // Return result id of global value with |opcode|, 0 if not present.
+ uint32_t GetGlobalValue(SpvOp opcode) const;
+
+ // Add global value with |opcode|, |result_id| and |type_id|
+ void AddGlobalValue(SpvOp opcode, uint32_t result_id, uint32_t type_id);
+
inline uint32_t id_bound() const { return header_.bound; }
// Iterators for debug instructions (excluding OpLine & OpNoLine) contained in
@@ -132,6 +138,9 @@ class Module {
// Returns 1 more than the maximum Id value mentioned in the module.
uint32_t ComputeIdBound() const;
+ // Returns true if module has capability |cap|
+ bool HasCapability(uint32_t cap);
+
private:
ModuleHeader header_; // Module header
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index f6dc1cdf..cf9a8251 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -75,7 +75,10 @@ bool Optimizer::Run(const uint32_t* original_binary,
if (module == nullptr) return false;
auto status = impl_->pass_manager.Run(module.get());
- if (status == opt::Pass::Status::SuccessWithChange) {
+ if (status == opt::Pass::Status::SuccessWithChange ||
+ (status == opt::Pass::Status::SuccessWithoutChange &&
+ (optimized_binary->data() != original_binary ||
+ optimized_binary->size() != original_binary_size))) {
optimized_binary->clear();
module->ToBinary(optimized_binary, /* skip_nop = */ true);
}
@@ -98,6 +101,17 @@ Optimizer::PassToken CreateSetSpecConstantDefaultValuePass(
MakeUnique<opt::SetSpecConstantDefaultValuePass>(id_value_map));
}
+Optimizer::PassToken CreateSetSpecConstantDefaultValuePass(
+ const std::unordered_map<uint32_t, std::vector<uint32_t>>& id_value_map) {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::SetSpecConstantDefaultValuePass>(id_value_map));
+}
+
+Optimizer::PassToken CreateFlattenDecorationPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::FlattenDecorationPass>());
+}
+
Optimizer::PassToken CreateFreezeSpecConstantValuePass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::FreezeSpecConstantValuePass>());
@@ -121,5 +135,20 @@ Optimizer::PassToken CreateEliminateDeadConstantPass() {
Optimizer::PassToken CreateInlinePass() {
return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::InlinePass>());
}
+
+Optimizer::PassToken CreateLocalAccessChainConvertPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::LocalAccessChainConvertPass>());
+}
+
+Optimizer::PassToken CreateLocalSingleBlockLoadStoreElimPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::LocalSingleBlockLoadStoreElimPass>());
+}
+
+Optimizer::PassToken CreateCompactIdsPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::CompactIdsPass>());
+}
} // namespace spvtools
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 26c55f9b..61361a7f 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -17,10 +17,14 @@
// A single header to include all passes.
+#include "compact_ids_pass.h"
#include "eliminate_dead_constant_pass.h"
+#include "flatten_decoration_pass.h"
#include "fold_spec_constant_op_and_composite_pass.h"
#include "inline_pass.h"
+#include "local_single_block_elim_pass.h"
#include "freeze_spec_constant_value_pass.h"
+#include "local_access_chain_convert_pass.h"
#include "null_pass.h"
#include "set_spec_constant_default_value_pass.h"
#include "strip_debug_info_pass.h"
diff --git a/source/opt/reflect.h b/source/opt/reflect.h
index 7f618153..16ea0bd4 100644
--- a/source/opt/reflect.h
+++ b/source/opt/reflect.h
@@ -15,7 +15,7 @@
#ifndef LIBSPIRV_OPT_REFLECT_H_
#define LIBSPIRV_OPT_REFLECT_H_
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
namespace spvtools {
namespace ir {
diff --git a/source/opt/set_spec_constant_default_value_pass.cpp b/source/opt/set_spec_constant_default_value_pass.cpp
index d49e2f15..02422ca4 100644
--- a/source/opt/set_spec_constant_default_value_pass.cpp
+++ b/source/opt/set_spec_constant_default_value_pass.cpp
@@ -14,6 +14,7 @@
#include "set_spec_constant_default_value_pass.h"
+#include <algorithm>
#include <cctype>
#include <cstring>
#include <tuple>
@@ -44,9 +45,9 @@ using spvutils::ParseAndEncodeNumber;
std::vector<uint32_t> ParseDefaultValueStr(const char* text,
const analysis::Type* type) {
std::vector<uint32_t> result;
- if (!strcmp(text, "true")) {
+ if (!strcmp(text, "true") && type->AsBool()) {
result.push_back(1u);
- } else if (!strcmp(text, "false")) {
+ } else if (!strcmp(text, "false") && type->AsBool()) {
result.push_back(0u);
} else {
NumberType number_type = {32, SPV_NUMBER_UNSIGNED_INT};
@@ -57,6 +58,11 @@ std::vector<uint32_t> ParseDefaultValueStr(const char* text,
} else if (const auto* FT = type->AsFloat()) {
number_type.bitwidth = FT->width();
number_type.kind = SPV_NUMBER_FLOATING;
+ } else {
+ // Does not handle types other then boolean, integer or float. Returns
+ // empty vector.
+ result.clear();
+ return result;
}
EncodeNumberStatus rc = ParseAndEncodeNumber(
text, number_type, [&result](uint32_t word) { result.push_back(word); },
@@ -69,6 +75,40 @@ std::vector<uint32_t> ParseDefaultValueStr(const char* text,
return result;
}
+// Given a bit pattern and a type, checks if the bit pattern is compatible
+// with the type. If so, returns the bit pattern, otherwise returns an empty
+// bit pattern. If the given bit pattern is empty, returns an empty bit
+// pattern. If the given type represents a SPIR-V Boolean type, the bit pattern
+// to be returned is determined with the following standard:
+// If any words in the input bit pattern are non zero, returns a bit pattern
+// with 0x1, which represents a 'true'.
+// If all words in the bit pattern are zero, returns a bit pattern with 0x0,
+// which represents a 'false'.
+std::vector<uint32_t> ParseDefaultValueBitPattern(
+ const std::vector<uint32_t>& input_bit_pattern,
+ const analysis::Type* type) {
+ std::vector<uint32_t> result;
+ if (type->AsBool()) {
+ if (std::any_of(input_bit_pattern.begin(), input_bit_pattern.end(),
+ [](uint32_t i) { return i != 0; })) {
+ result.push_back(1u);
+ } else {
+ result.push_back(0u);
+ }
+ return result;
+ } else if (const auto* IT = type->AsInteger()) {
+ if (IT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
+ return std::vector<uint32_t>(input_bit_pattern);
+ }
+ } else if (const auto* FT = type->AsFloat()) {
+ if (FT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
+ return std::vector<uint32_t>(input_bit_pattern);
+ }
+ }
+ result.clear();
+ return result;
+}
+
// Returns true if the given instruction's result id could have a SpecId
// decoration.
bool CanHaveSpecIdDecoration(const ir::Instruction& inst) {
@@ -200,15 +240,34 @@ Pass::Status SetSpecConstantDefaultValuePass::Process(ir::Module* module) {
}
if (!spec_inst) continue;
- // Search for the new default value for this spec id.
- auto iter = spec_id_to_value_.find(spec_id);
- if (iter == spec_id_to_value_.end()) continue;
+ // Get the default value bit pattern for this spec id.
+ std::vector<uint32_t> bit_pattern;
+
+ if (spec_id_to_value_str_.size() != 0) {
+ // Search for the new string-form default value for this spec id.
+ auto iter = spec_id_to_value_str_.find(spec_id);
+ if (iter == spec_id_to_value_str_.end()) {
+ continue;
+ }
+
+ // Gets the string of the default value and parses it to bit pattern
+ // with the type of the spec constant.
+ const std::string& default_value_str = iter->second;
+ bit_pattern = ParseDefaultValueStr(default_value_str.c_str(),
+ type_mgr.GetType(spec_inst->type_id()));
+
+ } else {
+ // Search for the new bit-pattern-form default value for this spec id.
+ auto iter = spec_id_to_value_bit_pattern_.find(spec_id);
+ if (iter == spec_id_to_value_bit_pattern_.end()) {
+ continue;
+ }
+
+ // Gets the bit-pattern of the default value from the map directly.
+ bit_pattern = ParseDefaultValueBitPattern(
+ iter->second, type_mgr.GetType(spec_inst->type_id()));
+ }
- // Gets the string of the default value and parses it to bit pattern
- // with the type of the spec constant.
- const std::string& default_value_str = iter->second;
- std::vector<uint32_t> bit_pattern = ParseDefaultValueStr(
- default_value_str.c_str(), type_mgr.GetType(spec_inst->type_id()));
if (bit_pattern.empty()) continue;
// Update the operand bit patterns of the spec constant defining
diff --git a/source/opt/set_spec_constant_default_value_pass.h b/source/opt/set_spec_constant_default_value_pass.h
index ad6b1424..15fcc0ab 100644
--- a/source/opt/set_spec_constant_default_value_pass.h
+++ b/source/opt/set_spec_constant_default_value_pass.h
@@ -29,14 +29,25 @@ namespace opt {
class SetSpecConstantDefaultValuePass : public Pass {
public:
using SpecIdToValueStrMap = std::unordered_map<uint32_t, std::string>;
+ using SpecIdToValueBitPatternMap =
+ std::unordered_map<uint32_t, std::vector<uint32_t>>;
using SpecIdToInstMap = std::unordered_map<uint32_t, ir::Instruction*>;
- // Constructs a pass instance with a map from spec ids to default values.
+ // Constructs a pass instance with a map from spec ids to default values
+ // in the form of string.
explicit SetSpecConstantDefaultValuePass(
const SpecIdToValueStrMap& default_values)
- : spec_id_to_value_(default_values) {}
+ : spec_id_to_value_str_(default_values), spec_id_to_value_bit_pattern_() {}
explicit SetSpecConstantDefaultValuePass(SpecIdToValueStrMap&& default_values)
- : spec_id_to_value_(std::move(default_values)) {}
+ : spec_id_to_value_str_(std::move(default_values)), spec_id_to_value_bit_pattern_() {}
+
+ // Constructs a pass instance with a map from spec ids to default values in
+ // the form of bit pattern.
+ explicit SetSpecConstantDefaultValuePass(
+ const SpecIdToValueBitPatternMap& default_values)
+ : spec_id_to_value_str_(), spec_id_to_value_bit_pattern_(default_values) {}
+ explicit SetSpecConstantDefaultValuePass(SpecIdToValueBitPatternMap&& default_values)
+ : spec_id_to_value_str_(), spec_id_to_value_bit_pattern_(std::move(default_values)) {}
const char* name() const override { return "set-spec-const-default-value"; }
Status Process(ir::Module*) override;
@@ -78,8 +89,14 @@ class SetSpecConstantDefaultValuePass : public Pass {
const char* str);
private:
- // The mapping from spec ids to their default values to be set.
- const SpecIdToValueStrMap spec_id_to_value_;
+ // The mappings from spec ids to default values. Two maps are defined here,
+ // each to be used for one specific form of the default values. Only one of
+ // them will be populated in practice.
+
+ // The mapping from spec ids to their string-form default values to be set.
+ const SpecIdToValueStrMap spec_id_to_value_str_;
+ // The mapping from spec ids to their bitpattern-form default values to be set.
+ const SpecIdToValueBitPatternMap spec_id_to_value_bit_pattern_;
};
} // namespace opt
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index ed8a13c7..8a125278 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -204,7 +204,7 @@ void TypeManager::AttachIfTypeDecoration(const ir::Instruction& inst) {
data.push_back(inst.GetSingleWordOperand(i));
}
if (Struct* st = target_type->AsStruct()) {
- st->AddMemeberDecoration(index, std::move(data));
+ st->AddMemberDecoration(index, std::move(data));
} else {
SPIRV_UNIMPLEMENTED(consumer_, "OpMemberDecorate non-struct type");
}
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index 285c1488..1d8cad24 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -220,8 +220,8 @@ Struct::Struct(const std::vector<Type*>& types) : element_types_(types) {
}
}
-void Struct::AddMemeberDecoration(uint32_t index,
- std::vector<uint32_t>&& decoration) {
+void Struct::AddMemberDecoration(uint32_t index,
+ std::vector<uint32_t>&& decoration) {
if (index >= element_types_.size()) {
assert(0 && "index out of bound");
return;
diff --git a/source/opt/types.h b/source/opt/types.h
index a46e4114..b6b62c53 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -22,7 +22,7 @@
#include <vector>
#include "spirv-tools/libspirv.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
namespace spvtools {
namespace opt {
@@ -262,7 +262,9 @@ class Struct : public Type {
Struct(const std::vector<Type*>& element_types);
Struct(const Struct&) = default;
- void AddMemeberDecoration(uint32_t index, std::vector<uint32_t>&& decoration);
+ // Adds a decoration to the member at the given index. The first word is the
+ // decoration enum, and the remaining words, if any, are its operands.
+ void AddMemberDecoration(uint32_t index, std::vector<uint32_t>&& decoration);
bool IsSame(Type* that) const override;
std::string str() const override;
diff --git a/source/print.cpp b/source/print.cpp
index c147a660..ff73b3d4 100644
--- a/source/print.cpp
+++ b/source/print.cpp
@@ -14,7 +14,7 @@
#include "print.h"
-#if defined(SPIRV_ANDROID) || defined(SPIRV_LINUX) || defined(SPIRV_MAC)
+#if defined(SPIRV_ANDROID) || defined(SPIRV_LINUX) || defined(SPIRV_MAC) || defined(SPIRV_FREEBSD)
namespace libspirv {
clr::reset::operator const char*() { return "\x1b[0m"; }
diff --git a/source/spirv_constant.h b/source/spirv_constant.h
index d807411c..c70ade10 100644
--- a/source/spirv_constant.h
+++ b/source/spirv_constant.h
@@ -16,7 +16,7 @@
#define LIBSPIRV_SPIRV_CONSTANT_H_
#include "spirv-tools/libspirv.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
// Version number macros.
diff --git a/source/spirv_definition.h b/source/spirv_definition.h
index e443809a..b82bda16 100644
--- a/source/spirv_definition.h
+++ b/source/spirv_definition.h
@@ -17,7 +17,7 @@
#include <cstdint>
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
#define spvIsInBitfield(value, bitfield) ((value) == ((value)&bitfield))
diff --git a/source/spirv_stats.cpp b/source/spirv_stats.cpp
new file mode 100644
index 00000000..2186e0d7
--- /dev/null
+++ b/source/spirv_stats.cpp
@@ -0,0 +1,217 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "spirv_stats.h"
+
+#include <cassert>
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "binary.h"
+#include "diagnostic.h"
+#include "enum_string_mapping.h"
+#include "extensions.h"
+#include "instruction.h"
+#include "opcode.h"
+#include "operand.h"
+#include "spirv-tools/libspirv.h"
+#include "spirv_endian.h"
+#include "spirv_validator_options.h"
+#include "validate.h"
+#include "val/instruction.h"
+#include "val/validation_state.h"
+
+using libspirv::Instruction;
+using libspirv::SpirvStats;
+using libspirv::ValidationState_t;
+
+namespace {
+
+// Helper class for stats aggregation. Receives as in/out parameter.
+// Constructs ValidationState and updates it by running validator for each
+// instruction.
+class StatsAggregator {
+ public:
+ StatsAggregator(SpirvStats* in_out_stats, const spv_const_context context) {
+ stats_ = in_out_stats;
+ vstate_.reset(new ValidationState_t(context, &validator_options_));
+ }
+
+ // Collects header statistics and sets correct id_bound.
+ spv_result_t ProcessHeader(
+ spv_endianness_t /* endian */, uint32_t /* magic */,
+ uint32_t version, uint32_t generator, uint32_t id_bound,
+ uint32_t /* schema */) {
+ vstate_->setIdBound(id_bound);
+ ++stats_->version_hist[version];
+ ++stats_->generator_hist[generator];
+ return SPV_SUCCESS;
+ }
+
+ // Runs validator to validate the instruction and update vstate_,
+ // then procession the instruction to collect stats.
+ spv_result_t ProcessInstruction(const spv_parsed_instruction_t* inst) {
+ const spv_result_t validation_result =
+ spvtools::ValidateInstructionAndUpdateValidationState(vstate_.get(), inst);
+ if (validation_result != SPV_SUCCESS)
+ return validation_result;
+
+ ProcessOpcode();
+ ProcessCapability();
+ ProcessExtension();
+ ProcessConstant();
+
+ return SPV_SUCCESS;
+ }
+
+ // Collects OpCapability statistics.
+ void ProcessCapability() {
+ const Instruction& inst = GetCurrentInstruction();
+ if (inst.opcode() != SpvOpCapability) return;
+ const uint32_t capability = inst.word(inst.operands()[0].offset);
+ ++stats_->capability_hist[capability];
+ }
+
+ // Collects OpExtension statistics.
+ void ProcessExtension() {
+ const Instruction& inst = GetCurrentInstruction();
+ if (inst.opcode() != SpvOpExtension) return;
+ const std::string extension = libspirv::GetExtensionString(&inst.c_inst());
+ ++stats_->extension_hist[extension];
+ }
+
+ // Collects OpCode statistics.
+ void ProcessOpcode() {
+ auto inst_it = vstate_->ordered_instructions().rbegin();
+ const SpvOp opcode = inst_it->opcode();
+ ++stats_->opcode_hist[opcode];
+
+ ++inst_it;
+ auto step_it = stats_->opcode_markov_hist.begin();
+ for (; inst_it != vstate_->ordered_instructions().rend() &&
+ step_it != stats_->opcode_markov_hist.end(); ++inst_it, ++step_it) {
+ auto& hist = (*step_it)[inst_it->opcode()];
+ ++hist[opcode];
+ }
+ }
+
+ // Collects OpConstant statistics.
+ void ProcessConstant() {
+ const Instruction& inst = GetCurrentInstruction();
+ if (inst.opcode() != SpvOpConstant) return;
+ const uint32_t type_id = inst.GetOperandAs<uint32_t>(0);
+ const auto type_decl_it = vstate_->all_definitions().find(type_id);
+ assert(type_decl_it != vstate_->all_definitions().end());
+ const Instruction& type_decl_inst = *type_decl_it->second;
+ const SpvOp type_op = type_decl_inst.opcode();
+ if (type_op == SpvOpTypeInt) {
+ const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
+ const uint32_t is_signed = type_decl_inst.GetOperandAs<uint32_t>(2);
+ assert(is_signed == 0 || is_signed == 1);
+ if (bit_width == 16) {
+ if (is_signed)
+ ++stats_->s16_constant_hist[inst.GetOperandAs<int16_t>(2)];
+ else
+ ++stats_->u16_constant_hist[inst.GetOperandAs<uint16_t>(2)];
+ } else if (bit_width == 32) {
+ if (is_signed)
+ ++stats_->s32_constant_hist[inst.GetOperandAs<int32_t>(2)];
+ else
+ ++stats_->u32_constant_hist[inst.GetOperandAs<uint32_t>(2)];
+ } else if (bit_width == 64) {
+ if (is_signed)
+ ++stats_->s64_constant_hist[inst.GetOperandAs<int64_t>(2)];
+ else
+ ++stats_->u64_constant_hist[inst.GetOperandAs<uint64_t>(2)];
+ } else {
+ assert(false && "TypeInt bit width is not 16, 32 or 64");
+ }
+ } else if (type_op == SpvOpTypeFloat) {
+ const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
+ if (bit_width == 32) {
+ ++stats_->f32_constant_hist[inst.GetOperandAs<float>(2)];
+ } else if (bit_width == 64) {
+ ++stats_->f64_constant_hist[inst.GetOperandAs<double>(2)];
+ } else {
+ assert(bit_width == 16);
+ }
+ }
+ }
+
+ SpirvStats* stats() {
+ return stats_;
+ }
+
+ private:
+ // Returns the current instruction (the one last processed by the validator).
+ const Instruction& GetCurrentInstruction() const {
+ return vstate_->ordered_instructions().back();
+ }
+
+ SpirvStats* stats_;
+ spv_validator_options_t validator_options_;
+ std::unique_ptr<ValidationState_t> vstate_;
+};
+
+spv_result_t ProcessHeader(
+ void* user_data, spv_endianness_t endian, uint32_t magic,
+ uint32_t version, uint32_t generator, uint32_t id_bound,
+ uint32_t schema) {
+ StatsAggregator* stats_aggregator =
+ reinterpret_cast<StatsAggregator*>(user_data);
+ return stats_aggregator->ProcessHeader(
+ endian, magic, version, generator, id_bound, schema);
+}
+
+spv_result_t ProcessInstruction(
+ void* user_data, const spv_parsed_instruction_t* inst) {
+ StatsAggregator* stats_aggregator =
+ reinterpret_cast<StatsAggregator*>(user_data);
+ return stats_aggregator->ProcessInstruction(inst);
+}
+
+} // namespace
+
+namespace libspirv {
+
+spv_result_t AggregateStats(
+ const spv_context_t& context, const uint32_t* words, const size_t num_words,
+ spv_diagnostic* pDiagnostic, SpirvStats* stats) {
+ spv_const_binary_t binary = {words, num_words};
+
+ spv_endianness_t endian;
+ spv_position_t position = {};
+ if (spvBinaryEndianness(&binary, &endian)) {
+ return libspirv::DiagnosticStream(position, context.consumer,
+ SPV_ERROR_INVALID_BINARY)
+ << "Invalid SPIR-V magic number.";
+ }
+
+ spv_header_t header;
+ if (spvBinaryHeaderGet(&binary, endian, &header)) {
+ return libspirv::DiagnosticStream(position, context.consumer,
+ SPV_ERROR_INVALID_BINARY)
+ << "Invalid SPIR-V header.";
+ }
+
+ StatsAggregator stats_aggregator(stats, &context);
+
+ return spvBinaryParse(&context, &stats_aggregator, words, num_words,
+ ProcessHeader, ProcessInstruction, pDiagnostic);
+}
+
+} // namespace libspirv
diff --git a/source/spirv_stats.h b/source/spirv_stats.h
new file mode 100644
index 00000000..9c7a41aa
--- /dev/null
+++ b/source/spirv_stats.h
@@ -0,0 +1,88 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_SPIRV_STATS_H_
+#define LIBSPIRV_SPIRV_STATS_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "spirv-tools/libspirv.hpp"
+
+namespace libspirv {
+
+struct SpirvStats {
+ // Version histogram, version_word -> count.
+ std::unordered_map<uint32_t, uint32_t> version_hist;
+
+ // Generator histogram, generator_word -> count.
+ std::unordered_map<uint32_t, uint32_t> generator_hist;
+
+ // Capability histogram, SpvCapabilityXXX -> count.
+ std::unordered_map<uint32_t, uint32_t> capability_hist;
+
+ // Extension histogram, extension_string -> count.
+ std::unordered_map<std::string, uint32_t> extension_hist;
+
+ // Opcode histogram, SpvOpXXX -> count.
+ std::unordered_map<uint32_t, uint32_t> opcode_hist;
+
+ // OpConstant u16 histogram, value -> count.
+ std::unordered_map<uint16_t, uint32_t> u16_constant_hist;
+
+ // OpConstant u32 histogram, value -> count.
+ std::unordered_map<uint32_t, uint32_t> u32_constant_hist;
+
+ // OpConstant u64 histogram, value -> count.
+ std::unordered_map<uint64_t, uint32_t> u64_constant_hist;
+
+ // OpConstant s16 histogram, value -> count.
+ std::unordered_map<int16_t, uint32_t> s16_constant_hist;
+
+ // OpConstant s32 histogram, value -> count.
+ std::unordered_map<int32_t, uint32_t> s32_constant_hist;
+
+ // OpConstant s64 histogram, value -> count.
+ std::unordered_map<int64_t, uint32_t> s64_constant_hist;
+
+ // OpConstant f32 histogram, value -> count.
+ std::unordered_map<float, uint32_t> f32_constant_hist;
+
+ // OpConstant f64 histogram, value -> count.
+ std::unordered_map<double, uint32_t> f64_constant_hist;
+
+ // Used to collect statistics on opcodes triggering other opcodes.
+ // Container scheme: gap between instructions -> cue opcode -> later opcode
+ // -> count.
+ // For example opcode_markov_hist[2][OpFMul][OpFAdd] corresponds to
+ // the number of times an OpMul appears, followed by 2 other instructions,
+ // followed by OpFAdd.
+ // opcode_markov_hist[0][OpFMul][OpFAdd] corresponds to how many times
+ // OpFMul appears, directly followed by OpFAdd.
+ // The size of the outer std::vector also serves as an input parameter,
+ // determining how many steps will be collected.
+ // I.e. do opcode_markov_hist.resize(1) to collect data for one step only.
+ std::vector<std::unordered_map<uint32_t,
+ std::unordered_map<uint32_t, uint32_t>>> opcode_markov_hist;
+};
+
+// Aggregates existing |stats| with new stats extracted from |binary|.
+spv_result_t AggregateStats(
+ const spv_context_t& context, const uint32_t* words, const size_t num_words,
+ spv_diagnostic* pDiagnostic, SpirvStats* stats);
+
+} // namespace libspirv
+
+#endif // LIBSPIRV_SPIRV_STATS_H_
diff --git a/source/spirv_target_env.cpp b/source/spirv_target_env.cpp
index 8a459a37..ed47f525 100644
--- a/source/spirv_target_env.cpp
+++ b/source/spirv_target_env.cpp
@@ -40,6 +40,8 @@ const char* spvTargetEnvDescription(spv_target_env env) {
return "SPIR-V 1.0 (under OpenCL 4.3 semantics)";
case SPV_ENV_OPENGL_4_5:
return "SPIR-V 1.0 (under OpenCL 4.5 semantics)";
+ case SPV_ENV_UNIVERSAL_1_2:
+ return "SPIR-V 1.2";
}
assert(0 && "Unhandled SPIR-V target environment");
return "";
@@ -57,8 +59,10 @@ uint32_t spvVersionForTargetEnv(spv_target_env env) {
case SPV_ENV_OPENGL_4_5:
return SPV_SPIRV_VERSION_WORD(1, 0);
case SPV_ENV_UNIVERSAL_1_1:
- case SPV_ENV_OPENCL_2_2:
return SPV_SPIRV_VERSION_WORD(1, 1);
+ case SPV_ENV_UNIVERSAL_1_2:
+ case SPV_ENV_OPENCL_2_2:
+ return SPV_SPIRV_VERSION_WORD(1, 2);
}
assert(0 && "Unhandled SPIR-V target environment");
return SPV_SPIRV_VERSION_WORD(0, 0);
@@ -77,6 +81,9 @@ bool spvParseTargetEnv(const char* s, spv_target_env* env) {
} else if (match("spv1.1")) {
if (env) *env = SPV_ENV_UNIVERSAL_1_1;
return true;
+ } else if (match("spv1.2")) {
+ if (env) *env = SPV_ENV_UNIVERSAL_1_2;
+ return true;
} else if (match("opencl2.1")) {
if (env) *env = SPV_ENV_OPENCL_2_1;
return true;
diff --git a/source/table.cpp b/source/table.cpp
index 8f858024..b8fb809c 100644
--- a/source/table.cpp
+++ b/source/table.cpp
@@ -28,6 +28,7 @@ spv_context spvContextCreate(spv_target_env env) {
case SPV_ENV_OPENGL_4_2:
case SPV_ENV_OPENGL_4_3:
case SPV_ENV_OPENGL_4_5:
+ case SPV_ENV_UNIVERSAL_1_2:
break;
default:
return nullptr;
diff --git a/source/table.h b/source/table.h
index 340cdb71..a7dffaaf 100644
--- a/source/table.h
+++ b/source/table.h
@@ -15,9 +15,7 @@
#ifndef LIBSPIRV_TABLE_H_
#define LIBSPIRV_TABLE_H_
-#include <string>
-
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
#include "extensions.h"
#include "message.h"
diff --git a/source/text.cpp b/source/text.cpp
index 6e68ac20..6a6846ea 100644
--- a/source/text.cpp
+++ b/source/text.cpp
@@ -659,13 +659,57 @@ spv_result_t SetHeader(spv_target_env env, const uint32_t bound,
return SPV_SUCCESS;
}
+// Collects all numeric ids in the module source into |numeric_ids|.
+// This function is essentially a dry-run of spvTextToBinary.
+spv_result_t GetNumericIds(const libspirv::AssemblyGrammar& grammar,
+ const spvtools::MessageConsumer& consumer,
+ const spv_text text,
+ std::set<uint32_t>* numeric_ids) {
+ libspirv::AssemblyContext context(text, consumer);
+
+ if (!text->str) return context.diagnostic() << "Missing assembly text.";
+
+ if (!grammar.isValid()) {
+ return SPV_ERROR_INVALID_TABLE;
+ }
+
+ // Skip past whitespace and comments.
+ context.advance();
+
+ while (context.hasText()) {
+ spv_instruction_t inst;
+
+ if (spvTextEncodeOpcode(grammar, &context, &inst)) {
+ return SPV_ERROR_INVALID_TEXT;
+ }
+
+ if (context.advance()) break;
+ }
+
+ *numeric_ids = context.GetNumericIds();
+ return SPV_SUCCESS;
+}
+
// Translates a given assembly language module into binary form.
// If a diagnostic is generated, it is not yet marked as being
// for a text-based input.
-spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar,
- const spvtools::MessageConsumer& consumer,
- const spv_text text, spv_binary* pBinary) {
- libspirv::AssemblyContext context(text, consumer);
+spv_result_t spvTextToBinaryInternal(
+ const libspirv::AssemblyGrammar& grammar,
+ const spvtools::MessageConsumer& consumer, const spv_text text,
+ const uint32_t options, spv_binary* pBinary) {
+ // The ids in this set will have the same values both in source and binary.
+ // All other ids will be generated by filling in the gaps.
+ std::set<uint32_t> ids_to_preserve;
+
+ if (options & SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS) {
+ // Collect all numeric ids from the source into ids_to_preserve.
+ const spv_result_t result =
+ GetNumericIds(grammar, consumer, text, &ids_to_preserve);
+ if (result != SPV_SUCCESS) return result;
+ }
+
+ libspirv::AssemblyContext context(text, consumer, std::move(ids_to_preserve));
+
if (!text->str) return context.diagnostic() << "Missing assembly text.";
if (!grammar.isValid()) {
@@ -725,6 +769,15 @@ spv_result_t spvTextToBinary(const spv_const_context context,
const char* input_text,
const size_t input_text_size, spv_binary* pBinary,
spv_diagnostic* pDiagnostic) {
+ return spvTextToBinaryWithOptions(
+ context, input_text, input_text_size, SPV_BINARY_TO_TEXT_OPTION_NONE,
+ pBinary, pDiagnostic);
+}
+
+spv_result_t spvTextToBinaryWithOptions(
+ const spv_const_context context, const char* input_text,
+ const size_t input_text_size, const uint32_t options, spv_binary* pBinary,
+ spv_diagnostic* pDiagnostic) {
spv_context_t hijack_context = *context;
if (pDiagnostic) {
*pDiagnostic = nullptr;
@@ -734,8 +787,8 @@ spv_result_t spvTextToBinary(const spv_const_context context,
spv_text_t text = {input_text, input_text_size};
libspirv::AssemblyGrammar grammar(&hijack_context);
- spv_result_t result =
- spvTextToBinaryInternal(grammar, hijack_context.consumer, &text, pBinary);
+ spv_result_t result = spvTextToBinaryInternal(
+ grammar, hijack_context.consumer, &text, options, pBinary);
if (pDiagnostic && *pDiagnostic) (*pDiagnostic)->isTextSource = true;
return result;
diff --git a/source/text_handler.cpp b/source/text_handler.cpp
index 8724b1cd..1806926e 100644
--- a/source/text_handler.cpp
+++ b/source/text_handler.cpp
@@ -14,6 +14,7 @@
#include "text_handler.h"
+#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
@@ -154,11 +155,33 @@ const IdType kUnknownType = {0, false, IdTypeClass::kBottom};
// This represents all of the data that is only valid for the duration of
// a single compilation.
uint32_t AssemblyContext::spvNamedIdAssignOrGet(const char* textValue) {
- if (named_ids_.end() == named_ids_.find(textValue)) {
- named_ids_[std::string(textValue)] = bound_++;
+ if (!ids_to_preserve_.empty()) {
+ uint32_t id = 0;
+ if (spvutils::ParseNumber(textValue, &id)) {
+ if (ids_to_preserve_.find(id) != ids_to_preserve_.end()) {
+ bound_ = std::max(bound_, id + 1);
+ return id;
+ }
+ }
+ }
+
+ const auto it = named_ids_.find(textValue);
+ if (it == named_ids_.end()) {
+ uint32_t id = next_id_++;
+ if (!ids_to_preserve_.empty()) {
+ while (ids_to_preserve_.find(id) != ids_to_preserve_.end()) {
+ id = next_id_++;
+ }
+ }
+
+ named_ids_.emplace(textValue, id);
+ bound_ = std::max(bound_, id + 1);
+ return id;
}
- return named_ids_[textValue];
+
+ return it->second;
}
+
uint32_t AssemblyContext::getBound() const { return bound_; }
spv_result_t AssemblyContext::advance() {
@@ -362,4 +385,14 @@ spv_ext_inst_type_t AssemblyContext::getExtInstTypeForId(uint32_t id) const {
return std::get<1>(*type);
}
+std::set<uint32_t> AssemblyContext::GetNumericIds() const {
+ std::set<uint32_t> ids;
+ for (const auto& kv : named_ids_) {
+ uint32_t id;
+ if (spvutils::ParseNumber(kv.first.c_str(), &id))
+ ids.insert(id);
+ }
+ return ids;
+}
+
} // namespace libspirv
diff --git a/source/text_handler.h b/source/text_handler.h
index 1bd004c1..1e17948d 100644
--- a/source/text_handler.h
+++ b/source/text_handler.h
@@ -117,8 +117,10 @@ class ClampToZeroIfUnsignedType<
// Encapsulates the data used during the assembly of a SPIR-V module.
class AssemblyContext {
public:
- AssemblyContext(spv_text text, const spvtools::MessageConsumer& consumer)
- : current_position_({}), consumer_(consumer), text_(text), bound_(1) {}
+ AssemblyContext(spv_text text, const spvtools::MessageConsumer& consumer,
+ std::set<uint32_t>&& ids_to_preserve = std::set<uint32_t>())
+ : current_position_({}), consumer_(consumer), text_(text), bound_(1),
+ next_id_(1), ids_to_preserve_(std::move(ids_to_preserve)) {}
// Assigns a new integer value to the given text ID, or returns the previously
// assigned integer value if the ID has been seen before.
@@ -224,6 +226,11 @@ class AssemblyContext {
// id is not the id for an extended instruction type.
spv_ext_inst_type_t getExtInstTypeForId(uint32_t id) const;
+ // Returns a set consisting of each ID generated by spvNamedIdAssignOrGet from
+ // a numeric ID text representation. For example, generated from "%12" but not
+ // from "%foo".
+ std::set<uint32_t> GetNumericIds() const;
+
private:
// Maps ID names to their corresponding numerical ids.
using spv_named_id_table = std::unordered_map<std::string, uint32_t>;
@@ -241,6 +248,8 @@ class AssemblyContext {
spvtools::MessageConsumer consumer_;
spv_text text_;
uint32_t bound_;
+ uint32_t next_id_;
+ std::set<uint32_t> ids_to_preserve_;
};
}
#endif // _LIBSPIRV_TEXT_HANDLER_H_
diff --git a/source/util/bit_stream.cpp b/source/util/bit_stream.cpp
new file mode 100644
index 00000000..5dac5638
--- /dev/null
+++ b/source/util/bit_stream.cpp
@@ -0,0 +1,387 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <cassert>
+#include <cstring>
+#include <sstream>
+#include <type_traits>
+
+#include "util/bit_stream.h"
+
+namespace spvutils {
+
+namespace {
+
+// Returns if the system is little-endian. Unfortunately only works during
+// runtime.
+bool IsLittleEndian() {
+ // This constant value allows the detection of the host machine's endianness.
+ // Accessing it as an array of bytes is valid due to C++11 section 3.10
+ // paragraph 10.
+ static const uint16_t kFF00 = 0xff00;
+ return reinterpret_cast<const unsigned char*>(&kFF00)[0] == 0;
+}
+
+// Copies uint8_t buffer to a uint64_t buffer.
+// Motivation: casting uint64_t* to uint8_t* is ok. Casting in the other
+// direction is only advisable if uint8_t* is aligned to 64-bit word boundary.
+std::vector<uint64_t> ToBuffer64(const std::vector<uint8_t>& in) {
+ std::vector<uint64_t> out;
+ out.resize((in.size() + 7) / 8, 0);
+ memcpy(out.data(), in.data(), in.size());
+ return out;
+}
+
+// Returns uint64_t containing the same bits as |val|.
+// Type size must be less than 8 bytes.
+template <typename T>
+uint64_t ToU64(T val) {
+ static_assert(sizeof(T) <= 8, "Type size too big");
+ uint64_t val64 = 0;
+ std::memcpy(&val64, &val, sizeof(T));
+ return val64;
+}
+
+// Returns value of type T containing the same bits as |val64|.
+// Type size must be less than 8 bytes. Upper (unused) bits of |val64| must be
+// zero (irrelevant, but is checked with assertion).
+template <typename T>
+T FromU64(uint64_t val64) {
+ assert(sizeof(T) == 8 || (val64 >> (sizeof(T) * 8)) == 0);
+ static_assert(sizeof(T) <= 8, "Type size too big");
+ T val = 0;
+ std::memcpy(&val, &val64, sizeof(T));
+ return val;
+}
+
+// Writes bits from |val| to |writer| in chunks of size |chunk_length|.
+// Signal bit is used to signal if the reader should expect another chunk:
+// 0 - no more chunks to follow
+// 1 - more chunks to follow
+// If number of written bits reaches |max_payload| last chunk is truncated.
+void WriteVariableWidthInternal(BitWriterInterface* writer, uint64_t val,
+ size_t chunk_length, size_t max_payload) {
+ assert(chunk_length > 0);
+ assert(chunk_length < max_payload);
+ assert(max_payload == 64 || (val >> max_payload) == 0);
+
+ if (val == 0) {
+ writer->WriteBits(0, chunk_length + 1);
+ return;
+ }
+
+ size_t payload_written = 0;
+
+ while (val) {
+ if (payload_written + chunk_length >= max_payload) {
+ // This has to be the last chunk.
+ // There is no need for the signal bit and the chunk can be truncated.
+ const size_t left_to_write = max_payload - payload_written;
+ assert((val >> left_to_write) == 0);
+ writer->WriteBits(val, left_to_write);
+ break;
+ }
+
+ writer->WriteBits(val, chunk_length);
+ payload_written += chunk_length;
+ val = val >> chunk_length;
+
+ // Write a single bit to signal if there is more to come.
+ writer->WriteBits(val ? 1 : 0, 1);
+ }
+}
+
+// Reads data written with WriteVariableWidthInternal. |chunk_length| and
+// |max_payload| should be identical to those used to write the data.
+// Returns false if the stream ends prematurely.
+bool ReadVariableWidthInternal(BitReaderInterface* reader, uint64_t* val,
+ size_t chunk_length, size_t max_payload) {
+ assert(chunk_length > 0);
+ assert(chunk_length <= max_payload);
+ size_t payload_read = 0;
+
+ while (payload_read + chunk_length < max_payload) {
+ uint64_t bits = 0;
+ if (reader->ReadBits(&bits, chunk_length) != chunk_length)
+ return false;
+
+ *val |= bits << payload_read;
+ payload_read += chunk_length;
+
+ uint64_t more_to_come = 0;
+ if (reader->ReadBits(&more_to_come, 1) != 1)
+ return false;
+
+ if (!more_to_come) {
+ return true;
+ }
+ }
+
+ // Need to read the last chunk which may be truncated. No signal bit follows.
+ uint64_t bits = 0;
+ const size_t left_to_read = max_payload - payload_read;
+ if (reader->ReadBits(&bits, left_to_read) != left_to_read)
+ return false;
+
+ *val |= bits << payload_read;
+ return true;
+}
+
+// Calls WriteVariableWidthInternal with the right max_payload argument.
+template <typename T>
+void WriteVariableWidthUnsigned(BitWriterInterface* writer, T val,
+ size_t chunk_length) {
+ static_assert(std::is_unsigned<T>::value, "Type must be unsigned");
+ static_assert(std::is_integral<T>::value, "Type must be integral");
+ WriteVariableWidthInternal(writer, val, chunk_length, sizeof(T) * 8);
+}
+
+// Calls ReadVariableWidthInternal with the right max_payload argument.
+template <typename T>
+bool ReadVariableWidthUnsigned(BitReaderInterface* reader, T* val,
+ size_t chunk_length) {
+ static_assert(std::is_unsigned<T>::value, "Type must be unsigned");
+ static_assert(std::is_integral<T>::value, "Type must be integral");
+ uint64_t val64 = 0;
+ if (!ReadVariableWidthInternal(reader, &val64, chunk_length, sizeof(T) * 8))
+ return false;
+ *val = static_cast<T>(val64);
+ assert(*val == val64);
+ return true;
+}
+
+// Encodes signed |val| to an unsigned value and calls
+// WriteVariableWidthInternal with the right max_payload argument.
+template <typename T>
+void WriteVariableWidthSigned(BitWriterInterface* writer, T val,
+ size_t chunk_length, size_t zigzag_exponent) {
+ static_assert(std::is_signed<T>::value, "Type must be signed");
+ static_assert(std::is_integral<T>::value, "Type must be integral");
+ WriteVariableWidthInternal(writer, EncodeZigZag(val, zigzag_exponent),
+ chunk_length, sizeof(T) * 8);
+}
+
+// Calls ReadVariableWidthInternal with the right max_payload argument
+// and decodes the value.
+template <typename T>
+bool ReadVariableWidthSigned(BitReaderInterface* reader, T* val,
+ size_t chunk_length, size_t zigzag_exponent) {
+ static_assert(std::is_signed<T>::value, "Type must be signed");
+ static_assert(std::is_integral<T>::value, "Type must be integral");
+ uint64_t encoded = 0;
+ if (!ReadVariableWidthInternal(reader, &encoded, chunk_length, sizeof(T) * 8))
+ return false;
+
+ const int64_t decoded = DecodeZigZag(encoded, zigzag_exponent);
+
+ *val = static_cast<T>(decoded);
+ assert(*val == decoded);
+ return true;
+}
+
+} // namespace
+
+void BitWriterInterface::WriteVariableWidthU64(uint64_t val,
+ size_t chunk_length) {
+ WriteVariableWidthUnsigned(this, val, chunk_length);
+}
+
+void BitWriterInterface::WriteVariableWidthU32(uint32_t val,
+ size_t chunk_length) {
+ WriteVariableWidthUnsigned(this, val, chunk_length);
+}
+
+void BitWriterInterface::WriteVariableWidthU16(uint16_t val,
+ size_t chunk_length) {
+ WriteVariableWidthUnsigned(this, val, chunk_length);
+}
+
+void BitWriterInterface::WriteVariableWidthU8(uint8_t val,
+ size_t chunk_length) {
+ WriteVariableWidthUnsigned(this, val, chunk_length);
+}
+
+void BitWriterInterface::WriteVariableWidthS64(int64_t val,
+ size_t chunk_length,
+ size_t zigzag_exponent) {
+ WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
+}
+
+void BitWriterInterface::WriteVariableWidthS32(int32_t val,
+ size_t chunk_length,
+ size_t zigzag_exponent) {
+ WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
+}
+
+void BitWriterInterface::WriteVariableWidthS16(int16_t val,
+ size_t chunk_length,
+ size_t zigzag_exponent) {
+ WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
+}
+
+void BitWriterInterface::WriteVariableWidthS8(int8_t val,
+ size_t chunk_length,
+ size_t zigzag_exponent) {
+ WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
+}
+
+BitWriterWord64::BitWriterWord64(size_t reserve_bits) : end_(0) {
+ buffer_.reserve(NumBitsToNumWords<64>(reserve_bits));
+}
+
+void BitWriterWord64::WriteBits(uint64_t bits, size_t num_bits) {
+ // Check that |bits| and |num_bits| are valid and consistent.
+ assert(num_bits <= 64);
+ const bool is_little_endian = IsLittleEndian();
+ assert(is_little_endian && "Big-endian architecture support not implemented");
+ if (!is_little_endian) return;
+
+ bits = GetLowerBits(bits, num_bits);
+
+ // Offset from the start of the current word.
+ const size_t offset = end_ % 64;
+
+ if (offset == 0) {
+ // If no offset, simply add |bits| as a new word to the buffer_.
+ buffer_.push_back(bits);
+ } else {
+ // Shift bits and add them to the current word after offset.
+ const uint64_t first_word = bits << offset;
+ buffer_.back() |= first_word;
+
+ // If we don't overflow to the next word, there is nothing more to do.
+
+ if (offset + num_bits > 64) {
+ // We overflow to the next word.
+ const uint64_t second_word = bits >> (64 - offset);
+ // Add remaining bits as a new word to buffer_.
+ buffer_.push_back(second_word);
+ }
+ }
+
+ // Move end_ into position for next write.
+ end_ += num_bits;
+ assert(buffer_.size() * 64 >= end_);
+}
+
+bool BitReaderInterface::ReadVariableWidthU64(uint64_t* val,
+ size_t chunk_length) {
+ return ReadVariableWidthUnsigned(this, val, chunk_length);
+}
+
+bool BitReaderInterface::ReadVariableWidthU32(uint32_t* val,
+ size_t chunk_length) {
+ return ReadVariableWidthUnsigned(this, val, chunk_length);
+}
+
+bool BitReaderInterface::ReadVariableWidthU16(uint16_t* val,
+ size_t chunk_length) {
+ return ReadVariableWidthUnsigned(this, val, chunk_length);
+}
+
+bool BitReaderInterface::ReadVariableWidthU8(uint8_t* val,
+ size_t chunk_length) {
+ return ReadVariableWidthUnsigned(this, val, chunk_length);
+}
+
+bool BitReaderInterface::ReadVariableWidthS64(int64_t* val,
+ size_t chunk_length,
+ size_t zigzag_exponent) {
+ return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
+}
+
+bool BitReaderInterface::ReadVariableWidthS32(int32_t* val,
+ size_t chunk_length,
+ size_t zigzag_exponent) {
+ return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
+}
+
+bool BitReaderInterface::ReadVariableWidthS16(int16_t* val,
+ size_t chunk_length,
+ size_t zigzag_exponent) {
+ return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
+}
+
+bool BitReaderInterface::ReadVariableWidthS8(int8_t* val,
+ size_t chunk_length,
+ size_t zigzag_exponent) {
+ return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
+}
+
+BitReaderWord64::BitReaderWord64(std::vector<uint64_t>&& buffer)
+ : buffer_(std::move(buffer)), pos_(0) {}
+
+BitReaderWord64::BitReaderWord64(const std::vector<uint8_t>& buffer)
+ : buffer_(ToBuffer64(buffer)), pos_(0) {}
+
+size_t BitReaderWord64::ReadBits(uint64_t* bits, size_t num_bits) {
+ assert(num_bits <= 64);
+ const bool is_little_endian = IsLittleEndian();
+ assert(is_little_endian && "Big-endian architecture support not implemented");
+ if (!is_little_endian) return 0;
+
+ if (ReachedEnd())
+ return 0;
+
+ // Index of the current word.
+ const size_t index = pos_ / 64;
+
+ // Bit position in the current word where we start reading.
+ const size_t offset = pos_ % 64;
+
+ // Read all bits from the current word (it might be too much, but
+ // excessive bits will be removed later).
+ *bits = buffer_[index] >> offset;
+
+ const size_t num_read_from_first_word = std::min(64 - offset, num_bits);
+ pos_ += num_read_from_first_word;
+
+ if (pos_ >= buffer_.size() * 64) {
+ // Reached end of buffer_.
+ return num_read_from_first_word;
+ }
+
+ if (offset + num_bits > 64) {
+ // Requested |num_bits| overflows to next word.
+ // Write all bits from the beginning of next word to *bits after offset.
+ *bits |= buffer_[index + 1] << (64 - offset);
+ pos_ += offset + num_bits - 64;
+ }
+
+ // We likely have written more bits than requested. Clear excessive bits.
+ *bits = GetLowerBits(*bits, num_bits);
+ return num_bits;
+}
+
+bool BitReaderWord64::ReachedEnd() const {
+ return pos_ >= buffer_.size() * 64;
+}
+
+bool BitReaderWord64::OnlyZeroesLeft() const {
+ if (ReachedEnd())
+ return true;
+
+ const size_t index = pos_ / 64;
+ if (index < buffer_.size() - 1)
+ return false;
+
+ assert(index == buffer_.size() - 1);
+
+ const size_t offset = pos_ % 64;
+ const uint64_t remaining_bits = buffer_[index] >> offset;
+ return !remaining_bits;
+}
+
+} // namespace spvutils
diff --git a/source/util/bit_stream.h b/source/util/bit_stream.h
new file mode 100644
index 00000000..a139b633
--- /dev/null
+++ b/source/util/bit_stream.h
@@ -0,0 +1,378 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Contains utils for reading, writing and debug printing bit streams.
+
+#ifndef LIBSPIRV_UTIL_BIT_STREAM_H_
+#define LIBSPIRV_UTIL_BIT_STREAM_H_
+
+#include <bitset>
+#include <cstdint>
+#include <string>
+#include <sstream>
+#include <vector>
+
+namespace spvutils {
+
+// Terminology:
+// Bits - usually used for a uint64 word, first bit is the lowest.
+// Stream - std::string of '0' and '1', read left-to-right,
+// i.e. first bit is at the front and not at the end as in
+// std::bitset::to_string().
+// Bitset - std::bitset corresponding to uint64 bits and to reverse(stream).
+
+// Converts number of bits to a respective number of chunks of size N.
+// For example NumBitsToNumWords<8> returns how many bytes are needed to store
+// |num_bits|.
+template <size_t N>
+inline size_t NumBitsToNumWords(size_t num_bits) {
+ return (num_bits + (N - 1)) / N;
+}
+
+// Returns value of the same type as |in|, where all but the first |num_bits|
+// are set to zero.
+template <typename T>
+inline T GetLowerBits(T in, size_t num_bits) {
+ return sizeof(T) * 8 == num_bits ? in : in & T((T(1) << num_bits) - T(1));
+}
+
+// Encodes signed integer as unsigned in zigzag order:
+// 0 -> 0
+// -1 -> 1
+// 1 -> 2
+// -2 -> 3
+// 2 -> 4
+// Motivation: -1 is 0xFF...FF what doesn't work very well with
+// WriteVariableWidth which prefers to have as many 0 bits as possible.
+inline uint64_t EncodeZigZag(int64_t val) {
+ return (val << 1) ^ (val >> 63);
+}
+
+// Decodes signed integer encoded with EncodeZigZag.
+inline int64_t DecodeZigZag(uint64_t val) {
+ if (val & 1) {
+ // Negative.
+ // 1 -> -1
+ // 3 -> -2
+ // 5 -> -3
+ return -1 - (val >> 1);
+ } else {
+ // Non-negative.
+ // 0 -> 0
+ // 2 -> 1
+ // 4 -> 2
+ return val >> 1;
+ }
+}
+
+// Encodes signed integer as unsigned. This is a generalized version of
+// EncodeZigZag, designed to favor small positive numbers.
+// Values are transformed in blocks of 2^|block_exponent|.
+// If |block_exponent| is zero, then this degenerates into normal EncodeZigZag.
+// Example when |block_exponent| is 1 (return value is the index):
+// 0, 1, -1, -2, 2, 3, -3, -4, 4, 5, -5, -6, 6, 7, -7, -8
+// Example when |block_exponent| is 2:
+// 0, 1, 2, 3, -1, -2, -3, -4, 4, 5, 6, 7, -5, -6, -7, -8
+inline uint64_t EncodeZigZag(int64_t val, size_t block_exponent) {
+ assert(block_exponent < 64);
+ const uint64_t uval = static_cast<uint64_t>(val >= 0 ? val : -val - 1);
+ const uint64_t block_num = ((uval >> block_exponent) << 1) + (val >= 0 ? 0 : 1);
+ const uint64_t pos = GetLowerBits(uval, block_exponent);
+ return (block_num << block_exponent) + pos;
+}
+
+// Decodes signed integer encoded with EncodeZigZag. |block_exponent| must be
+// the same.
+inline int64_t DecodeZigZag(uint64_t val, size_t block_exponent) {
+ assert(block_exponent < 64);
+ const uint64_t block_num = val >> block_exponent;
+ const uint64_t pos = GetLowerBits(val, block_exponent);
+ if (block_num & 1) {
+ // Negative.
+ return -1LL - ((block_num >> 1) << block_exponent) - pos;
+ } else {
+ // Positive.
+ return ((block_num >> 1) << block_exponent) + pos;
+ }
+}
+
+// Converts |buffer| to a stream of '0' and '1'.
+template <typename T>
+std::string BufferToStream(const std::vector<T>& buffer) {
+ std::stringstream ss;
+ for (auto it = buffer.begin(); it != buffer.end(); ++it) {
+ std::string str = std::bitset<sizeof(T) * 8>(*it).to_string();
+ // Strings generated by std::bitset::to_string are read right to left.
+ // Reversing to left to right.
+ std::reverse(str.begin(), str.end());
+ ss << str;
+ }
+ return ss.str();
+}
+
+// Converts a left-to-right input string of '0' and '1' to a buffer of |T|
+// words.
+template <typename T>
+std::vector<T> StreamToBuffer(std::string str) {
+ // The input string is left-to-right, the input argument of std::bitset needs
+ // to right-to-left. Instead of reversing tokens, reverse the entire string
+ // and iterate tokens from end to begin.
+ std::reverse(str.begin(), str.end());
+ const int word_size = static_cast<int>(sizeof(T) * 8);
+ const int str_length = static_cast<int>(str.length());
+ std::vector<T> buffer;
+ buffer.reserve(NumBitsToNumWords<sizeof(T)>(str.length()));
+ for (int index = str_length - word_size; index >= 0; index -= word_size) {
+ buffer.push_back(static_cast<T>(std::bitset<sizeof(T) * 8>(
+ str, index, word_size).to_ullong()));
+ }
+ const size_t suffix_length = str.length() % word_size;
+ if (suffix_length != 0) {
+ buffer.push_back(static_cast<T>(std::bitset<sizeof(T) * 8>(
+ str, 0, suffix_length).to_ullong()));
+ }
+ return buffer;
+}
+
+// Adds '0' chars at the end of the string until the size is a multiple of N.
+template <size_t N>
+inline std::string PadToWord(std::string&& str) {
+ const size_t tail_length = str.size() % N;
+ if (tail_length != 0)
+ str += std::string(N - tail_length, '0');
+ return str;
+}
+
+// Adds '0' chars at the end of the string until the size is a multiple of N.
+template <size_t N>
+inline std::string PadToWord(const std::string& str) {
+ return PadToWord<N>(std::string(str));
+}
+
+// Converts a left-to-right stream of bits to std::bitset.
+template <size_t N>
+inline std::bitset<N> StreamToBitset(std::string str) {
+ std::reverse(str.begin(), str.end());
+ return std::bitset<N>(str);
+}
+
+// Converts first |num_bits| of std::bitset to a left-to-right stream of bits.
+template <size_t N>
+inline std::string BitsetToStream(const std::bitset<N>& bits, size_t num_bits = N) {
+ std::string str = bits.to_string().substr(N - num_bits);
+ std::reverse(str.begin(), str.end());
+ return str;
+}
+
+// Converts a left-to-right stream of bits to uint64.
+inline uint64_t StreamToBits(std::string str) {
+ std::reverse(str.begin(), str.end());
+ return std::bitset<64>(str).to_ullong();
+}
+
+// Converts first |num_bits| stored in uint64 to a left-to-right stream of bits.
+inline std::string BitsToStream(uint64_t bits, size_t num_bits = 64) {
+ std::bitset<64> bitset(bits);
+ return BitsetToStream(bitset, num_bits);
+}
+
+// Base class for writing sequences of bits.
+class BitWriterInterface {
+ public:
+ BitWriterInterface() {}
+ virtual ~BitWriterInterface() {}
+
+ // Writes lower |num_bits| in |bits| to the stream.
+ // |num_bits| must be no greater than 64.
+ virtual void WriteBits(uint64_t bits, size_t num_bits) = 0;
+
+ // Writes left-to-right string of '0' and '1' to stream.
+ // String length must be no greater than 64.
+ // Note: "01" will be writen as 0x2, not 0x1. The string doesn't represent
+ // numbers but a stream of bits in the order they come from encoder.
+ virtual void WriteStream(const std::string& bits) {
+ WriteBits(StreamToBits(bits), bits.length());
+ }
+
+ // Writes lower |num_bits| in |bits| to the stream.
+ // |num_bits| must be no greater than 64.
+ template <size_t N>
+ void WriteBitset(const std::bitset<N>& bits, size_t num_bits = N) {
+ WriteBits(bits.to_ullong(), num_bits);
+ }
+
+ // Writes |val| in chunks of size |chunk_length| followed by a signal bit:
+ // 0 - no more chunks to follow
+ // 1 - more chunks to follow
+ // for example 255 is encoded into 1111 1 1111 0 for chunk length 4.
+ // The last chunk can be truncated and signal bit omitted, if the entire
+ // payload (for example 16 bit for uint16_t has already been written).
+ void WriteVariableWidthU64(uint64_t val, size_t chunk_length);
+ void WriteVariableWidthU32(uint32_t val, size_t chunk_length);
+ void WriteVariableWidthU16(uint16_t val, size_t chunk_length);
+ void WriteVariableWidthU8(uint8_t val, size_t chunk_length);
+ void WriteVariableWidthS64(
+ int64_t val, size_t chunk_length, size_t zigzag_exponent);
+ void WriteVariableWidthS32(
+ int32_t val, size_t chunk_length, size_t zigzag_exponent);
+ void WriteVariableWidthS16(
+ int16_t val, size_t chunk_length, size_t zigzag_exponent);
+ void WriteVariableWidthS8(
+ int8_t val, size_t chunk_length, size_t zigzag_exponent);
+
+ // Returns number of bits written.
+ virtual size_t GetNumBits() const = 0;
+
+ // Provides direct access to the buffer data if implemented.
+ virtual const uint8_t* GetData() const {
+ return nullptr;
+ }
+
+ // Returns buffer size in bytes.
+ size_t GetDataSizeBytes() const {
+ return NumBitsToNumWords<8>(GetNumBits());
+ }
+
+ // Generates and returns byte array containing written bits.
+ virtual std::vector<uint8_t> GetDataCopy() const = 0;
+
+ BitWriterInterface(const BitWriterInterface&) = delete;
+ BitWriterInterface& operator=(const BitWriterInterface&) = delete;
+};
+
+// This class is an implementation of BitWriterInterface, using
+// std::vector<uint64_t> to store written bits.
+class BitWriterWord64 : public BitWriterInterface {
+ public:
+ explicit BitWriterWord64(size_t reserve_bits = 64);
+
+ void WriteBits(uint64_t bits, size_t num_bits) override;
+
+ size_t GetNumBits() const override {
+ return end_;
+ }
+
+ const uint8_t* GetData() const override {
+ return reinterpret_cast<const uint8_t*>(buffer_.data());
+ }
+
+ std::vector<uint8_t> GetDataCopy() const override {
+ return std::vector<uint8_t>(GetData(), GetData() + GetDataSizeBytes());
+ }
+
+ // Returns written stream as std::string, padded with zeroes so that the
+ // length is a multiple of 64.
+ std::string GetStreamPadded64() const {
+ return BufferToStream(buffer_);
+ }
+
+ private:
+ std::vector<uint64_t> buffer_;
+ // Total number of bits written so far. Named 'end' as analogy to std::end().
+ size_t end_;
+};
+
+// Base class for reading sequences of bits.
+class BitReaderInterface {
+ public:
+ BitReaderInterface() {}
+ virtual ~BitReaderInterface() {}
+
+ // Reads |num_bits| from the stream, stores them in |bits|.
+ // Returns number of read bits. |num_bits| must be no greater than 64.
+ virtual size_t ReadBits(uint64_t* bits, size_t num_bits) = 0;
+
+ // Reads |num_bits| from the stream, stores them in |bits|.
+ // Returns number of read bits. |num_bits| must be no greater than 64.
+ template <size_t N>
+ size_t ReadBitset(std::bitset<N>* bits, size_t num_bits = N) {
+ uint64_t val = 0;
+ size_t num_read = ReadBits(&val, num_bits);
+ if (num_read) {
+ *bits = std::bitset<N>(val);
+ }
+ return num_read;
+ }
+
+ // Reads |num_bits| from the stream, returns string in left-to-right order.
+ // The length of the returned string may be less than |num_bits| if end was
+ // reached.
+ std::string ReadStream(size_t num_bits) {
+ uint64_t bits = 0;
+ size_t num_read = ReadBits(&bits, num_bits);
+ return BitsToStream(bits, num_read);
+ }
+
+ // These two functions define 'hard' and 'soft' EOF.
+ //
+ // Returns true if the end of the buffer was reached.
+ virtual bool ReachedEnd() const = 0;
+ // Returns true if we reached the end of the buffer or are nearing it and only
+ // zero bits are left to read. Implementations of this function are allowed to
+ // commit a "false negative" error if the end of the buffer was not reached,
+ // i.e. it can return false even if indeed only zeroes are left.
+ // It is assumed that the consumer expects that
+ // the buffer stream ends with padding zeroes, and would accept this as a
+ // 'soft' EOF. Implementations of this class do not necessarily need to
+ // implement this, default behavior can simply delegate to ReachedEnd().
+ virtual bool OnlyZeroesLeft() const {
+ return ReachedEnd();
+ }
+
+ // Reads value encoded with WriteVariableWidthXXX (see BitWriterInterface).
+ // Reader and writer must use the same |chunk_length| and variable type.
+ // Returns true on success, false if the bit stream ends prematurely.
+ bool ReadVariableWidthU64(uint64_t* val, size_t chunk_length);
+ bool ReadVariableWidthU32(uint32_t* val, size_t chunk_length);
+ bool ReadVariableWidthU16(uint16_t* val, size_t chunk_length);
+ bool ReadVariableWidthU8(uint8_t* val, size_t chunk_length);
+ bool ReadVariableWidthS64(
+ int64_t* val, size_t chunk_length, size_t zigzag_exponent);
+ bool ReadVariableWidthS32(
+ int32_t* val, size_t chunk_length, size_t zigzag_exponent);
+ bool ReadVariableWidthS16(
+ int16_t* val, size_t chunk_length, size_t zigzag_exponent);
+ bool ReadVariableWidthS8(
+ int8_t* val, size_t chunk_length, size_t zigzag_exponent);
+
+ BitReaderInterface(const BitReaderInterface&) = delete;
+ BitReaderInterface& operator=(const BitReaderInterface&) = delete;
+};
+
+// This class is an implementation of BitReaderInterface which accepts both
+// uint8_t and uint64_t buffers as input. uint64_t buffers are consumed and
+// owned. uint8_t buffers are copied.
+class BitReaderWord64 : public BitReaderInterface {
+ public:
+ // Consumes and owns the buffer.
+ explicit BitReaderWord64(std::vector<uint64_t>&& buffer);
+
+ // Copies the buffer and casts it to uint64.
+ // Consuming the original buffer and casting it to uint64 is difficult,
+ // as it would potentially cause data misalignment and poor performance.
+ explicit BitReaderWord64(const std::vector<uint8_t>& buffer);
+
+ size_t ReadBits(uint64_t* bits, size_t num_bits) override;
+ bool ReachedEnd() const override;
+ bool OnlyZeroesLeft() const override;
+
+ BitReaderWord64() = delete;
+ private:
+ const std::vector<uint64_t> buffer_;
+ size_t pos_;
+};
+
+} // namespace spvutils
+
+#endif // LIBSPIRV_UTIL_BIT_STREAM_H_
diff --git a/source/util/string_utils.cpp b/source/util/string_utils.cpp
index d70c0b91..830f1a3b 100644
--- a/source/util/string_utils.cpp
+++ b/source/util/string_utils.cpp
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <algorithm>
#include <cstdint>
#include <type_traits>
@@ -36,4 +37,3 @@ std::string CardinalToOrdinal(size_t cardinal) {
}
} // namespace spvutils
-
diff --git a/source/val/basic_block.h b/source/val/basic_block.h
index 29d730cb..81f0f666 100644
--- a/source/val/basic_block.h
+++ b/source/val/basic_block.h
@@ -15,7 +15,7 @@
#ifndef LIBSPIRV_VAL_BASICBLOCK_H_
#define LIBSPIRV_VAL_BASICBLOCK_H_
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
#include <cstdint>
diff --git a/source/val/function.cpp b/source/val/function.cpp
index a3c2215f..42f7ec15 100644
--- a/source/val/function.cpp
+++ b/source/val/function.cpp
@@ -24,6 +24,7 @@
#include "val/basic_block.h"
#include "val/construct.h"
#include "validate.h"
+#include "cfa.h"
using std::ignore;
using std::list;
@@ -32,55 +33,6 @@ using std::pair;
using std::tie;
using std::vector;
-namespace {
-
-using libspirv::BasicBlock;
-
-// Computes a minimal set of root nodes required to traverse, in the forward
-// direction, the CFG represented by the given vector of blocks, and successor
-// and predecessor functions. When considering adding two nodes, each having
-// predecessors, favour using the one that appears earlier on the input blocks
-// list.
-std::vector<BasicBlock*> TraversalRoots(const std::vector<BasicBlock*>& blocks,
- libspirv::get_blocks_func succ_func,
- libspirv::get_blocks_func pred_func) {
- // The set of nodes which have been visited from any of the roots so far.
- std::unordered_set<const BasicBlock*> visited;
-
- auto mark_visited = [&visited](const BasicBlock* b) { visited.insert(b); };
- auto ignore_block = [](const BasicBlock*) {};
- auto ignore_blocks = [](const BasicBlock*, const BasicBlock*) {};
-
- auto traverse_from_root = [&mark_visited, &succ_func, &ignore_block,
- &ignore_blocks](const BasicBlock* entry) {
- DepthFirstTraversal(entry, succ_func, mark_visited, ignore_block,
- ignore_blocks);
- };
-
- std::vector<BasicBlock*> result;
-
- // First collect nodes without predecessors.
- for (auto block : blocks) {
- if (pred_func(block)->empty()) {
- assert(visited.count(block) == 0 && "Malformed graph!");
- result.push_back(block);
- traverse_from_root(block);
- }
- }
-
- // Now collect other stranded nodes. These must be in unreachable cycles.
- for (auto block : blocks) {
- if (visited.count(block) == 0) {
- result.push_back(block);
- traverse_from_root(block);
- }
- }
-
- return result;
-}
-
-} // anonymous namespace
-
namespace libspirv {
// Universal Limit of ResultID + 1
@@ -322,42 +274,14 @@ void Function::ComputeAugmentedCFG() {
// the predecessors of the pseudo exit block.
auto succ_func = [](const BasicBlock* b) { return b->successors(); };
auto pred_func = [](const BasicBlock* b) { return b->predecessors(); };
- auto sources = TraversalRoots(ordered_blocks_, succ_func, pred_func);
-
- // For the predecessor traversals, reverse the order of blocks. This
- // will affect the post-dominance calculation as follows:
- // - Suppose you have blocks A and B, with A appearing before B in
- // the list of blocks.
- // - Also, A branches only to B, and B branches only to A.
- // - We want to compute A as dominating B, and B as post-dominating B.
- // By using reversed blocks for predecessor traversal roots discovery,
- // we'll add an edge from B to the pseudo-exit node, rather than from A.
- // All this is needed to correctly process the dominance/post-dominance
- // constraint when A is a loop header that points to itself as its
- // own continue target, and B is the latch block for the loop.
- std::vector<BasicBlock*> reversed_blocks(ordered_blocks_.rbegin(),
- ordered_blocks_.rend());
- auto sinks = TraversalRoots(reversed_blocks, pred_func, succ_func);
-
- // Wire up the pseudo entry block.
- augmented_successors_map_[&pseudo_entry_block_] = sources;
- for (auto block : sources) {
- auto& augmented_preds = augmented_predecessors_map_[block];
- const auto& preds = *block->predecessors();
- augmented_preds.reserve(1 + preds.size());
- augmented_preds.push_back(&pseudo_entry_block_);
- augmented_preds.insert(augmented_preds.end(), preds.begin(), preds.end());
- }
-
- // Wire up the pseudo exit block.
- augmented_predecessors_map_[&pseudo_exit_block_] = sinks;
- for (auto block : sinks) {
- auto& augmented_succ = augmented_successors_map_[block];
- const auto& succ = *block->successors();
- augmented_succ.reserve(1 + succ.size());
- augmented_succ.push_back(&pseudo_exit_block_);
- augmented_succ.insert(augmented_succ.end(), succ.begin(), succ.end());
- }
+ spvtools::CFA<BasicBlock>::ComputeAugmentedCFG(
+ ordered_blocks_,
+ &pseudo_entry_block_,
+ &pseudo_exit_block_,
+ &augmented_successors_map_,
+ &augmented_predecessors_map_,
+ succ_func,
+ pred_func);
};
Construct& Function::AddConstruct(const Construct& new_construct) {
diff --git a/source/val/function.h b/source/val/function.h
index f5087f71..7eb8dcdf 100644
--- a/source/val/function.h
+++ b/source/val/function.h
@@ -22,7 +22,7 @@
#include <vector>
#include "spirv-tools/libspirv.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
#include "val/basic_block.h"
#include "val/construct.h"
diff --git a/source/val/instruction.h b/source/val/instruction.h
index 1d8fe91a..31b463a6 100644
--- a/source/val/instruction.h
+++ b/source/val/instruction.h
@@ -15,8 +15,8 @@
#ifndef LIBSPIRV_VAL_INSTRUCTION_H_
#define LIBSPIRV_VAL_INSTRUCTION_H_
+#include <cassert>
#include <cstdint>
-
#include <functional>
#include <utility>
#include <vector>
@@ -71,6 +71,20 @@ class Instruction {
return operands_;
}
+ /// Provides direct access to the stored C instruction object.
+ const spv_parsed_instruction_t& c_inst() const {
+ return inst_;
+ }
+
+ // Casts the words belonging to the operand under |index| to |T| and returns.
+ template <typename T>
+ T GetOperandAs(size_t index) const {
+ const spv_parsed_operand_t& operand = operands_.at(index);
+ assert(operand.num_words * 4 >= sizeof(T));
+ assert(operand.offset + operand.num_words <= inst_.num_words);
+ return *reinterpret_cast<const T*>(&words_[operand.offset]);
+ }
+
private:
const std::vector<uint32_t> words_;
const std::vector<spv_parsed_operand_t> operands_;
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 12e864d3..350b09df 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -300,6 +300,14 @@ void ValidationState_t::RegisterCapability(SpvCapability cap) {
features_.declare_int16_type = true;
features_.declare_float16_type = true;
features_.free_fp_rounding_mode = true;
+ break;
+ case SpvCapabilityVariablePointers:
+ features_.variable_pointers = true;
+ features_.variable_pointers_storage_buffer = true;
+ break;
+ case SpvCapabilityVariablePointersStorageBuffer:
+ features_.variable_pointers_storage_buffer = true;
+ break;
default:
break;
}
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index a61fbabb..87a80ce4 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -27,7 +27,7 @@
#include "diagnostic.h"
#include "enum_set.h"
#include "spirv-tools/libspirv.h"
-#include "spirv/1.1/spirv.h"
+#include "spirv/1.2/spirv.h"
#include "spirv_definition.h"
#include "val/function.h"
#include "val/instruction.h"
@@ -62,6 +62,12 @@ class ValidationState_t {
bool free_fp_rounding_mode = false; // Allow the FPRoundingMode decoration
// and its vaules to be used without
// requiring any capability
+
+ // Allow functionalities enabled by VariablePointers capability.
+ bool variable_pointers = false;
+ // Allow functionalities enabled by VariablePointersStorageBuffer
+ // capability.
+ bool variable_pointers_storage_buffer = false;
};
ValidationState_t(const spv_const_context context,
diff --git a/source/validate.cpp b/source/validate.cpp
index ccfce05c..03bc6d9a 100644
--- a/source/validate.cpp
+++ b/source/validate.cpp
@@ -20,6 +20,7 @@
#include <algorithm>
#include <functional>
#include <iterator>
+#include <memory>
#include <sstream>
#include <string>
#include <vector>
@@ -236,18 +237,19 @@ spv_result_t spvValidate(const spv_const_context context,
spv_result_t ValidateBinaryUsingContextAndValidationState(
const spv_context_t& context, const uint32_t* words, const size_t num_words,
spv_diagnostic* pDiagnostic, ValidationState_t* vstate) {
- spv_const_binary binary = new spv_const_binary_t{words, num_words};
+ auto binary = std::unique_ptr<spv_const_binary_t>(
+ new spv_const_binary_t{words, num_words});
spv_endianness_t endian;
spv_position_t position = {};
- if (spvBinaryEndianness(binary, &endian)) {
+ if (spvBinaryEndianness(binary.get(), &endian)) {
return libspirv::DiagnosticStream(position, context.consumer,
SPV_ERROR_INVALID_BINARY)
<< "Invalid SPIR-V magic number.";
}
spv_header_t header;
- if (spvBinaryHeaderGet(binary, endian, &header)) {
+ if (spvBinaryHeaderGet(binary.get(), endian, &header)) {
return libspirv::DiagnosticStream(position, context.consumer,
SPV_ERROR_INVALID_BINARY)
<< "Invalid SPIR-V header.";
@@ -374,7 +376,9 @@ spv_result_t spvValidateWithOptions(const spv_const_context context,
hijack_context, binary->code, binary->wordCount, pDiagnostic, &vstate);
}
-spv_result_t spvtools::ValidateBinaryAndKeepValidationState(
+namespace spvtools {
+
+spv_result_t ValidateBinaryAndKeepValidationState(
const spv_const_context context, spv_const_validator_options options,
const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic,
std::unique_ptr<ValidationState_t>* vstate) {
@@ -390,3 +394,9 @@ spv_result_t spvtools::ValidateBinaryAndKeepValidationState(
hijack_context, words, num_words, pDiagnostic, vstate->get());
}
+spv_result_t ValidateInstructionAndUpdateValidationState(
+ ValidationState_t* vstate, const spv_parsed_instruction_t* inst) {
+ return ProcessInstruction(vstate, inst);
+}
+
+} // namespace spvtools
diff --git a/source/validate.h b/source/validate.h
index b6d05168..34d1ffef 100644
--- a/source/validate.h
+++ b/source/validate.h
@@ -34,55 +34,6 @@ class BasicBlock;
using get_blocks_func =
std::function<const std::vector<BasicBlock*>*(const BasicBlock*)>;
-/// @brief Depth first traversal starting from the \p entry BasicBlock
-///
-/// This function performs a depth first traversal from the \p entry
-/// BasicBlock and calls the pre/postorder functions when it needs to process
-/// the node in pre order, post order. It also calls the backedge function
-/// when a back edge is encountered.
-///
-/// @param[in] entry The root BasicBlock of a CFG
-/// @param[in] successor_func A function which will return a pointer to the
-/// successor nodes
-/// @param[in] preorder A function that will be called for every block in a
-/// CFG following preorder traversal semantics
-/// @param[in] postorder A function that will be called for every block in a
-/// CFG following postorder traversal semantics
-/// @param[in] backedge A function that will be called when a backedge is
-/// encountered during a traversal
-/// NOTE: The @p successor_func and predecessor_func each return a pointer to a
-/// collection such that iterators to that collection remain valid for the
-/// lifetime of the algorithm.
-void DepthFirstTraversal(
- const BasicBlock* entry, get_blocks_func successor_func,
- std::function<void(const BasicBlock*)> preorder,
- std::function<void(const BasicBlock*)> postorder,
- std::function<void(const BasicBlock*, const BasicBlock*)> backedge);
-
-/// @brief Calculates dominator edges for a set of blocks
-///
-/// Computes dominators using the algorithm of Cooper, Harvey, and Kennedy
-/// "A Simple, Fast Dominance Algorithm", 2001.
-///
-/// The algorithm assumes there is a unique root node (a node without
-/// predecessors), and it is therefore at the end of the postorder vector.
-///
-/// This function calculates the dominator edges for a set of blocks in the CFG.
-/// Uses the dominator algorithm by Cooper et al.
-///
-/// @param[in] postorder A vector of blocks in post order traversal order
-/// in a CFG
-/// @param[in] predecessor_func Function used to get the predecessor nodes of a
-/// block
-///
-/// @return the dominator tree of the graph, as a vector of pairs of nodes.
-/// The first node in the pair is a node in the graph. The second node in the
-/// pair is its immediate dominator in the sense of Cooper et.al., where a block
-/// without predecessors (such as the root node) is its own immediate dominator.
-std::vector<std::pair<BasicBlock*, BasicBlock*>> CalculateDominators(
- const std::vector<const BasicBlock*>& postorder,
- get_blocks_func predecessor_func);
-
/// @brief Performs the Control Flow Graph checks
///
/// @param[in] _ the validation state of the module
@@ -213,6 +164,12 @@ spv_result_t ValidateBinaryAndKeepValidationState(
const spv_const_context context, spv_const_validator_options options,
const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic,
std::unique_ptr<libspirv::ValidationState_t>* vstate);
+
+// Performs validation for a single instruction and updates given validation
+// state.
+spv_result_t ValidateInstructionAndUpdateValidationState(
+ libspirv::ValidationState_t* vstate, const spv_parsed_instruction_t* inst);
+
} // namespace spvtools
#endif // LIBSPIRV_VALIDATE_H_
diff --git a/source/validate_cfg.cpp b/source/validate_cfg.cpp
index a41101f0..9234542e 100644
--- a/source/validate_cfg.cpp
+++ b/source/validate_cfg.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include "validate.h"
+#include "cfa.h"
#include <algorithm>
#include <cassert>
@@ -58,132 +59,8 @@ using bb_ptr = BasicBlock*;
using cbb_ptr = const BasicBlock*;
using bb_iter = vector<BasicBlock*>::const_iterator;
-struct block_info {
- cbb_ptr block; ///< pointer to the block
- bb_iter iter; ///< Iterator to the current child node being processed
-};
-
-/// Returns true if a block with @p id is found in the @p work_list vector
-///
-/// @param[in] work_list Set of blocks visited in the the depth first traversal
-/// of the CFG
-/// @param[in] id The ID of the block being checked
-///
-/// @return true if the edge work_list.back().block->id() => id is a back-edge
-bool FindInWorkList(const vector<block_info>& work_list, uint32_t id) {
- for (const auto b : work_list) {
- if (b.block->id() == id) return true;
- }
- return false;
-}
-
} // namespace
-void DepthFirstTraversal(const BasicBlock* entry,
- get_blocks_func successor_func,
- function<void(cbb_ptr)> preorder,
- function<void(cbb_ptr)> postorder,
- function<void(cbb_ptr, cbb_ptr)> backedge) {
- unordered_set<uint32_t> processed;
-
- /// NOTE: work_list is the sequence of nodes from the root node to the node
- /// being processed in the traversal
- vector<block_info> work_list;
- work_list.reserve(10);
-
- work_list.push_back({entry, begin(*successor_func(entry))});
- preorder(entry);
- processed.insert(entry->id());
-
- while (!work_list.empty()) {
- block_info& top = work_list.back();
- if (top.iter == end(*successor_func(top.block))) {
- postorder(top.block);
- work_list.pop_back();
- } else {
- BasicBlock* child = *top.iter;
- top.iter++;
- if (FindInWorkList(work_list, child->id())) {
- backedge(top.block, child);
- }
- if (processed.count(child->id()) == 0) {
- preorder(child);
- work_list.emplace_back(
- block_info{child, begin(*successor_func(child))});
- processed.insert(child->id());
- }
- }
- }
-}
-
-vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
- const vector<cbb_ptr>& postorder, get_blocks_func predecessor_func) {
- struct block_detail {
- size_t dominator; ///< The index of blocks's dominator in post order array
- size_t postorder_index; ///< The index of the block in the post order array
- };
- const size_t undefined_dom = postorder.size();
-
- unordered_map<cbb_ptr, block_detail> idoms;
- for (size_t i = 0; i < postorder.size(); i++) {
- idoms[postorder[i]] = {undefined_dom, i};
- }
- idoms[postorder.back()].dominator = idoms[postorder.back()].postorder_index;
-
- bool changed = true;
- while (changed) {
- changed = false;
- for (auto b = postorder.rbegin() + 1; b != postorder.rend(); ++b) {
- const vector<BasicBlock*>& predecessors = *predecessor_func(*b);
- // Find the first processed/reachable predecessor that is reachable
- // in the forward traversal.
- auto res = find_if(begin(predecessors), end(predecessors),
- [&idoms, undefined_dom](BasicBlock* pred) {
- return idoms.count(pred) &&
- idoms[pred].dominator != undefined_dom;
- });
- if (res == end(predecessors)) continue;
- const BasicBlock* idom = *res;
- size_t idom_idx = idoms[idom].postorder_index;
-
- // all other predecessors
- for (const auto* p : predecessors) {
- if (idom == p) continue;
- // Only consider nodes reachable in the forward traversal.
- // Otherwise the intersection doesn't make sense and will never
- // terminate.
- if (!idoms.count(p)) continue;
- if (idoms[p].dominator != undefined_dom) {
- size_t finger1 = idoms[p].postorder_index;
- size_t finger2 = idom_idx;
- while (finger1 != finger2) {
- while (finger1 < finger2) {
- finger1 = idoms[postorder[finger1]].dominator;
- }
- while (finger2 < finger1) {
- finger2 = idoms[postorder[finger2]].dominator;
- }
- }
- idom_idx = finger1;
- }
- }
- if (idoms[*b].dominator != idom_idx) {
- idoms[*b].dominator = idom_idx;
- changed = true;
- }
- }
- }
-
- vector<pair<bb_ptr, bb_ptr>> out;
- for (auto idom : idoms) {
- // NOTE: performing a const cast for convenient usage with
- // UpdateImmediateDominators
- out.push_back({const_cast<BasicBlock*>(get<0>(idom)),
- const_cast<BasicBlock*>(postorder[get<1>(idom).dominator])});
- }
- return out;
-}
-
void printDominatorList(const BasicBlock& b) {
std::cout << b.id() << " is dominated by: ";
const BasicBlock* bb = &b;
@@ -406,28 +283,28 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) {
auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
if (!function.ordered_blocks().empty()) {
/// calculate dominators
- DepthFirstTraversal(
+ spvtools::CFA<libspirv::BasicBlock>::DepthFirstTraversal(
function.first_block(), function.AugmentedCFGSuccessorsFunction(),
ignore_block, [&](cbb_ptr b) { postorder.push_back(b); },
ignore_edge);
- auto edges = libspirv::CalculateDominators(
+ auto edges = spvtools::CFA<libspirv::BasicBlock>::CalculateDominators(
postorder, function.AugmentedCFGPredecessorsFunction());
for (auto edge : edges) {
edge.first->SetImmediateDominator(edge.second);
}
/// calculate post dominators
- DepthFirstTraversal(
+ spvtools::CFA<libspirv::BasicBlock>::DepthFirstTraversal(
function.pseudo_exit_block(),
function.AugmentedCFGPredecessorsFunction(), ignore_block,
[&](cbb_ptr b) { postdom_postorder.push_back(b); }, ignore_edge);
- auto postdom_edges = libspirv::CalculateDominators(
+ auto postdom_edges = spvtools::CFA<libspirv::BasicBlock>::CalculateDominators(
postdom_postorder, function.AugmentedCFGSuccessorsFunction());
for (auto edge : postdom_edges) {
edge.first->SetImmediatePostDominator(edge.second);
}
/// calculate back edges.
- DepthFirstTraversal(
+ spvtools::CFA<libspirv::BasicBlock>::DepthFirstTraversal(
function.pseudo_entry_block(),
function
.AugmentedCFGSuccessorsFunctionIncludingHeaderToContinueEdge(),
diff --git a/source/validate_id.cpp b/source/validate_id.cpp
index 039c9eb4..a9082c49 100644
--- a/source/validate_id.cpp
+++ b/source/validate_id.cpp
@@ -1116,10 +1116,17 @@ bool idUsage::isValid<SpvOpLoad>(const spv_instruction_t* inst,
<< inst->words[resultTypeIndex] << "' is not defind.";
return false;
}
+ const bool uses_variable_pointer =
+ module_.features().variable_pointers ||
+ module_.features().variable_pointers_storage_buffer;
auto pointerIndex = 3;
auto pointer = module_.FindDef(inst->words[pointerIndex]);
- if (!pointer || (addressingModel == SpvAddressingModelLogical &&
- !spvOpcodeReturnsLogicalPointer(pointer->opcode()))) {
+ if (!pointer ||
+ (addressingModel == SpvAddressingModelLogical &&
+ ((!uses_variable_pointer &&
+ !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
+ (uses_variable_pointer &&
+ !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
DIAG(pointerIndex) << "OpLoad Pointer <id> '" << inst->words[pointerIndex]
<< "' is not a logical pointer.";
return false;
@@ -1145,10 +1152,17 @@ bool idUsage::isValid<SpvOpLoad>(const spv_instruction_t* inst,
template <>
bool idUsage::isValid<SpvOpStore>(const spv_instruction_t* inst,
const spv_opcode_desc) {
- auto pointerIndex = 1;
+ const bool uses_variable_pointer =
+ module_.features().variable_pointers ||
+ module_.features().variable_pointers_storage_buffer;
+ const auto pointerIndex = 1;
auto pointer = module_.FindDef(inst->words[pointerIndex]);
- if (!pointer || (addressingModel == SpvAddressingModelLogical &&
- !spvOpcodeReturnsLogicalPointer(pointer->opcode()))) {
+ if (!pointer ||
+ (addressingModel == SpvAddressingModelLogical &&
+ ((!uses_variable_pointer &&
+ !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
+ (uses_variable_pointer &&
+ !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
DIAG(pointerIndex) << "OpStore Pointer <id> '" << inst->words[pointerIndex]
<< "' is not a logical pointer.";
return false;
@@ -2404,8 +2418,11 @@ bool idUsage::isValid<SpvOpReturnValue>(const spv_instruction_t* inst,
<< "' is missing or void.";
return false;
}
+ const bool uses_variable_pointer =
+ module_.features().variable_pointers ||
+ module_.features().variable_pointers_storage_buffer;
if (addressingModel == SpvAddressingModelLogical &&
- SpvOpTypePointer == valueType->opcode()) {
+ SpvOpTypePointer == valueType->opcode() && !uses_variable_pointer) {
DIAG(valueIndex)
<< "OpReturnValue value's type <id> '" << value->type_id()
<< "' is a pointer, which is invalid in the Logical addressing model.";
diff --git a/source/validate_instruction.cpp b/source/validate_instruction.cpp
index 49d21132..aafb7b04 100644
--- a/source/validate_instruction.cpp
+++ b/source/validate_instruction.cpp
@@ -395,9 +395,10 @@ spv_result_t InstructionPass(ValidationState_t& _,
const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
if (opcode == SpvOpExtension)
CheckIfKnownExtension(_, inst);
- if (opcode == SpvOpCapability)
+ if (opcode == SpvOpCapability) {
_.RegisterCapability(
static_cast<SpvCapability>(inst->words[inst->operands[0].offset]));
+ }
if (opcode == SpvOpMemoryModel) {
_.set_addressing_model(
static_cast<SpvAddressingModel>(inst->words[inst->operands[0].offset]));
diff --git a/source/validate_type_unique.cpp b/source/validate_type_unique.cpp
index a7f4b9e6..98768dfa 100644
--- a/source/validate_type_unique.cpp
+++ b/source/validate_type_unique.cpp
@@ -26,9 +26,15 @@ namespace libspirv {
// Validates that type declarations are unique, unless multiple declarations
// of the same data type are allowed by the specification.
// (see section 2.8 Types and Variables)
+// Doesn't do anything if SPV_VAL_ignore_type_decl_unique was declared in the
+// module.
spv_result_t TypeUniquePass(ValidationState_t& _,
const spv_parsed_instruction_t* inst) {
+ if (_.HasExtension(Extension::kSPV_VALIDATOR_ignore_type_decl_unique))
+ return SPV_SUCCESS;
+
const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
+
if (spvOpcodeGeneratesType(opcode)) {
if (opcode == SpvOpTypeArray || opcode == SpvOpTypeRuntimeArray ||
opcode == SpvOpTypeStruct) {
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 79c2684d..926dadef 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -59,6 +59,7 @@ function(add_spvtools_unittest)
target_link_libraries(${target} PRIVATE ${ARG_LIBS})
target_link_libraries(${target} PRIVATE gmock_main)
add_test(NAME spirv-tools-${target} COMMAND ${target})
+ set_property(TARGET ${target} PROPERTY FOLDER "SPIRV-Tools tests")
endif()
endfunction()
@@ -158,5 +159,16 @@ add_spvtools_unittest(
SRCS log_test.cpp
LIBS ${SPIRV_TOOLS})
+add_spvtools_unittest(
+ TARGET preserve_numeric_ids
+ SRCS preserve_numeric_ids_test.cpp
+ LIBS ${SPIRV_TOOLS})
+
+add_spvtools_unittest(
+ TARGET bit_stream
+ SRCS bit_stream.cpp
+ LIBS ${SPIRV_TOOLS})
+
add_subdirectory(opt)
add_subdirectory(val)
+add_subdirectory(stats)
diff --git a/test/binary_header_get_test.cpp b/test/binary_header_get_test.cpp
index 047d3a02..9ccd0a95 100644
--- a/test/binary_header_get_test.cpp
+++ b/test/binary_header_get_test.cpp
@@ -50,7 +50,7 @@ TEST_F(BinaryHeaderGet, Default) {
ASSERT_EQ(SPV_SUCCESS, spvBinaryHeaderGet(&const_bin, endian, &header));
ASSERT_EQ(static_cast<uint32_t>(SpvMagicNumber), header.magic);
- ASSERT_EQ(0x00010100u, header.version);
+ ASSERT_EQ(0x00010200u, header.version);
ASSERT_EQ(static_cast<uint32_t>(SPV_GENERATOR_CODEPLAY), header.generator);
ASSERT_EQ(1u, header.bound);
ASSERT_EQ(0u, header.schema);
diff --git a/test/binary_to_text_test.cpp b/test/binary_to_text_test.cpp
index ed1bee56..fcbf8835 100644
--- a/test/binary_to_text_test.cpp
+++ b/test/binary_to_text_test.cpp
@@ -235,10 +235,12 @@ TEST_P(RoundTripInstructionsTest, Sample) {
Eq(get<1>(GetParam())));
}
+// clang-format off
INSTANTIATE_TEST_CASE_P(
NumericLiterals, RoundTripInstructionsTest,
// This test is independent of environment, so just test the one.
- Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0),
+ Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_UNIVERSAL_1_2),
::testing::ValuesIn(std::vector<std::string>{
"%1 = OpTypeInt 12 0\n%2 = OpConstant %1 1867\n",
"%1 = OpTypeInt 12 1\n%2 = OpConstant %1 1867\n",
@@ -267,10 +269,12 @@ INSTANTIATE_TEST_CASE_P(
"%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1p+1024\n", // Inf
"%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1p+1024\n", // -Inf
})), );
+// clang-format on
INSTANTIATE_TEST_CASE_P(
MemoryAccessMasks, RoundTripInstructionsTest,
- Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1),
+ Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_UNIVERSAL_1_2),
::testing::ValuesIn(std::vector<std::string>{
"OpStore %1 %2\n", // 3 words long.
"OpStore %1 %2 None\n", // 4 words long, explicit final 0.
@@ -285,7 +289,8 @@ INSTANTIATE_TEST_CASE_P(
INSTANTIATE_TEST_CASE_P(
FPFastMathModeMasks, RoundTripInstructionsTest,
Combine(
- ::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1),
+ ::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_UNIVERSAL_1_2),
::testing::ValuesIn(std::vector<std::string>{
"OpDecorate %1 FPFastMathMode None\n",
"OpDecorate %1 FPFastMathMode NotNaN\n",
@@ -301,7 +306,8 @@ INSTANTIATE_TEST_CASE_P(
INSTANTIATE_TEST_CASE_P(
LoopControlMasks, RoundTripInstructionsTest,
- Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1),
+ Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_UNIVERSAL_1_2),
::testing::ValuesIn(std::vector<std::string>{
"OpLoopMerge %1 %2 None\n", "OpLoopMerge %1 %2 Unroll\n",
"OpLoopMerge %1 %2 DontUnroll\n",
@@ -309,7 +315,8 @@ INSTANTIATE_TEST_CASE_P(
})), );
INSTANTIATE_TEST_CASE_P(LoopControlMasksV11, RoundTripInstructionsTest,
- Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_1),
+ Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_UNIVERSAL_1_2),
::testing::ValuesIn(std::vector<std::string>{
"OpLoopMerge %1 %2 DependencyInfinite\n",
"OpLoopMerge %1 %2 DependencyLength 8\n",
@@ -317,7 +324,8 @@ INSTANTIATE_TEST_CASE_P(LoopControlMasksV11, RoundTripInstructionsTest,
INSTANTIATE_TEST_CASE_P(
SelectionControlMasks, RoundTripInstructionsTest,
- Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1),
+ Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_UNIVERSAL_1_2),
::testing::ValuesIn(std::vector<std::string>{
"OpSelectionMerge %1 None\n", "OpSelectionMerge %1 Flatten\n",
"OpSelectionMerge %1 DontFlatten\n",
@@ -326,7 +334,8 @@ INSTANTIATE_TEST_CASE_P(
INSTANTIATE_TEST_CASE_P(
FunctionControlMasks, RoundTripInstructionsTest,
- Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1),
+ Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_UNIVERSAL_1_2),
::testing::ValuesIn(std::vector<std::string>{
"%2 = OpFunction %1 None %3\n",
"%2 = OpFunction %1 Inline %3\n",
@@ -338,7 +347,8 @@ INSTANTIATE_TEST_CASE_P(
INSTANTIATE_TEST_CASE_P(
ImageMasks, RoundTripInstructionsTest,
- Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1),
+ Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_UNIVERSAL_1_2),
::testing::ValuesIn(std::vector<std::string>{
"%2 = OpImageFetch %1 %3 %4\n",
"%2 = OpImageFetch %1 %3 %4 None\n",
@@ -358,6 +368,17 @@ INSTANTIATE_TEST_CASE_P(
" Bias|Lod|Grad|ConstOffset|Offset|ConstOffsets|Sample|MinLod"
" %5 %6 %7 %8 %9 %10 %11 %12 %13\n"})), );
+INSTANTIATE_TEST_CASE_P(
+ NewInstructionsInSPIRV1_2, RoundTripInstructionsTest,
+ Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_2),
+ ::testing::ValuesIn(std::vector<std::string>{
+ "OpExecutionModeId %1 SubgroupsPerWorkgroupId %2\n",
+ "OpExecutionModeId %1 LocalSizeId %2 %3 %4\n",
+ "OpExecutionModeId %1 LocalSizeHintId %2\n",
+ "OpDecorateId %1 AlignmentId %2\n",
+ "OpDecorateId %1 MaxByteOffsetId %2\n",
+ })), );
+
using MaskSorting = TextToBinaryTest;
TEST_F(MaskSorting, MasksAreSortedFromLSBToMSB) {
diff --git a/test/bit_stream.cpp b/test/bit_stream.cpp
new file mode 100644
index 00000000..8deeb4e8
--- /dev/null
+++ b/test/bit_stream.cpp
@@ -0,0 +1,1138 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "util/bit_stream.h"
+
+namespace {
+
+using spvutils::BitWriterInterface;
+using spvutils::BitReaderInterface;
+using spvutils::BitWriterWord64;
+using spvutils::BitReaderWord64;
+using spvutils::StreamToBuffer;
+using spvutils::BufferToStream;
+using spvutils::NumBitsToNumWords;
+using spvutils::PadToWord;
+using spvutils::StreamToBitset;
+using spvutils::BitsetToStream;
+using spvutils::BitsToStream;
+using spvutils::StreamToBits;
+using spvutils::GetLowerBits;
+using spvutils::EncodeZigZag;
+using spvutils::DecodeZigZag;
+
+// A simple and inefficient implementatition of BitWriterInterface,
+// using std::stringstream. Intended for tests only.
+class BitWriterStringStream : public BitWriterInterface {
+ public:
+ void WriteStream(const std::string& bits) override {
+ ss_ << bits;
+ }
+
+ void WriteBits(uint64_t bits, size_t num_bits) override {
+ assert(num_bits <= 64);
+ ss_ << BitsToStream(bits, num_bits);
+ }
+
+ size_t GetNumBits() const override {
+ return ss_.str().size();
+ }
+
+ std::vector<uint8_t> GetDataCopy() const override {
+ return StreamToBuffer<uint8_t>(ss_.str());
+ }
+
+ std::string GetStreamRaw() const {
+ return ss_.str();
+ }
+
+ private:
+ std::stringstream ss_;
+};
+
+// A simple and inefficient implementatition of BitReaderInterface.
+// Intended for tests only.
+class BitReaderFromString : public BitReaderInterface {
+ public:
+ explicit BitReaderFromString(std::string&& str)
+ : str_(std::move(str)), pos_(0) {}
+
+ explicit BitReaderFromString(const std::vector<uint64_t>& buffer)
+ : str_(BufferToStream(buffer)), pos_(0) {}
+
+ explicit BitReaderFromString(const std::vector<uint8_t>& buffer)
+ : str_(PadToWord<64>(BufferToStream(buffer))), pos_(0) {}
+
+ size_t ReadBits(uint64_t* bits, size_t num_bits) override {
+ if (ReachedEnd())
+ return 0;
+ std::string sub = str_.substr(pos_, num_bits);
+ *bits = StreamToBits(sub);
+ pos_ += sub.length();
+ return sub.length();
+ }
+
+ bool ReachedEnd() const override {
+ return pos_ >= str_.length();
+ }
+
+ const std::string& GetStreamPadded64() const {
+ return str_;
+ }
+
+ private:
+ std::string str_;
+ size_t pos_;
+};
+
+TEST(NumBitsToNumWords, Word8) {
+ EXPECT_EQ(0u, NumBitsToNumWords<8>(0));
+ EXPECT_EQ(1u, NumBitsToNumWords<8>(1));
+ EXPECT_EQ(1u, NumBitsToNumWords<8>(7));
+ EXPECT_EQ(1u, NumBitsToNumWords<8>(8));
+ EXPECT_EQ(2u, NumBitsToNumWords<8>(9));
+ EXPECT_EQ(2u, NumBitsToNumWords<8>(16));
+ EXPECT_EQ(3u, NumBitsToNumWords<8>(17));
+ EXPECT_EQ(3u, NumBitsToNumWords<8>(23));
+ EXPECT_EQ(3u, NumBitsToNumWords<8>(24));
+ EXPECT_EQ(4u, NumBitsToNumWords<8>(25));
+}
+
+TEST(NumBitsToNumWords, Word64) {
+ EXPECT_EQ(0u, NumBitsToNumWords<64>(0));
+ EXPECT_EQ(1u, NumBitsToNumWords<64>(1));
+ EXPECT_EQ(1u, NumBitsToNumWords<64>(64));
+ EXPECT_EQ(2u, NumBitsToNumWords<64>(65));
+ EXPECT_EQ(2u, NumBitsToNumWords<64>(128));
+ EXPECT_EQ(3u, NumBitsToNumWords<64>(129));
+}
+
+TEST(ZigZagCoding, Encode) {
+ EXPECT_EQ(0u, EncodeZigZag(0));
+ EXPECT_EQ(1u, EncodeZigZag(-1));
+ EXPECT_EQ(2u, EncodeZigZag(1));
+ EXPECT_EQ(3u, EncodeZigZag(-2));
+ EXPECT_EQ(4u, EncodeZigZag(2));
+ EXPECT_EQ(5u, EncodeZigZag(-3));
+ EXPECT_EQ(6u, EncodeZigZag(3));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max() - 1,
+ EncodeZigZag(std::numeric_limits<int64_t>::max()));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max(),
+ EncodeZigZag(std::numeric_limits<int64_t>::min()));
+}
+
+TEST(ZigZagCoding, Decode) {
+ EXPECT_EQ(0, DecodeZigZag(0));
+ EXPECT_EQ(-1, DecodeZigZag(1));
+ EXPECT_EQ(1, DecodeZigZag(2));
+ EXPECT_EQ(-2, DecodeZigZag(3));
+ EXPECT_EQ(2, DecodeZigZag(4));
+ EXPECT_EQ(-3, DecodeZigZag(5));
+ EXPECT_EQ(3, DecodeZigZag(6));
+ EXPECT_EQ(std::numeric_limits<int64_t>::min(),
+ DecodeZigZag(std::numeric_limits<uint64_t>::max()));
+ EXPECT_EQ(std::numeric_limits<int64_t>::max(),
+ DecodeZigZag(std::numeric_limits<uint64_t>::max() - 1));
+}
+
+TEST(ZigZagCoding, Encode0) {
+ EXPECT_EQ(0u, EncodeZigZag(0, 0));
+ EXPECT_EQ(1u, EncodeZigZag(-1, 0));
+ EXPECT_EQ(2u, EncodeZigZag(1, 0));
+ EXPECT_EQ(3u, EncodeZigZag(-2, 0));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max() - 1,
+ EncodeZigZag(std::numeric_limits<int64_t>::max(), 0));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max(),
+ EncodeZigZag(std::numeric_limits<int64_t>::min(), 0));
+}
+
+TEST(ZigZagCoding, Decode0) {
+ EXPECT_EQ(0, DecodeZigZag(0, 0));
+ EXPECT_EQ(-1, DecodeZigZag(1, 0));
+ EXPECT_EQ(1, DecodeZigZag(2, 0));
+ EXPECT_EQ(-2, DecodeZigZag(3, 0));
+ EXPECT_EQ(std::numeric_limits<int64_t>::min(),
+ DecodeZigZag(std::numeric_limits<uint64_t>::max(), 0));
+ EXPECT_EQ(std::numeric_limits<int64_t>::max(),
+ DecodeZigZag(std::numeric_limits<uint64_t>::max() - 1, 0));
+}
+
+TEST(ZigZagCoding, Decode0SameAsNormalZigZag) {
+ for (int32_t i = -10000; i < 10000; i += 123) {
+ ASSERT_EQ(DecodeZigZag(i), DecodeZigZag(i, 0));
+ }
+}
+
+TEST(ZigZagCoding, Encode0SameAsNormalZigZag) {
+ for (uint32_t i = 0; i < 10000; i += 123) {
+ ASSERT_EQ(EncodeZigZag(i), EncodeZigZag(i, 0));
+ }
+}
+
+TEST(ZigZagCoding, Encode1) {
+ EXPECT_EQ(0u, EncodeZigZag(0, 1));
+ EXPECT_EQ(1u, EncodeZigZag(1, 1));
+ EXPECT_EQ(2u, EncodeZigZag(-1, 1));
+ EXPECT_EQ(3u, EncodeZigZag(-2, 1));
+ EXPECT_EQ(4u, EncodeZigZag(2, 1));
+ EXPECT_EQ(5u, EncodeZigZag(3, 1));
+ EXPECT_EQ(6u, EncodeZigZag(-3, 1));
+ EXPECT_EQ(7u, EncodeZigZag(-4, 1));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max() - 2,
+ EncodeZigZag(std::numeric_limits<int64_t>::max(), 1));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max() - 1,
+ EncodeZigZag(std::numeric_limits<int64_t>::min() + 1, 1));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max(),
+ EncodeZigZag(std::numeric_limits<int64_t>::min(), 1));
+}
+
+TEST(ZigZagCoding, Decode1) {
+ EXPECT_EQ(0, DecodeZigZag(0, 1));
+ EXPECT_EQ(1, DecodeZigZag(1, 1));
+ EXPECT_EQ(-1, DecodeZigZag(2, 1));
+ EXPECT_EQ(-2, DecodeZigZag(3, 1));
+ EXPECT_EQ(2, DecodeZigZag(4, 1));
+ EXPECT_EQ(3, DecodeZigZag(5, 1));
+ EXPECT_EQ(-3, DecodeZigZag(6, 1));
+ EXPECT_EQ(-4, DecodeZigZag(7, 1));
+ EXPECT_EQ(std::numeric_limits<int64_t>::min(),
+ DecodeZigZag(std::numeric_limits<uint64_t>::max(), 1));
+ EXPECT_EQ(std::numeric_limits<int64_t>::min() + 1,
+ DecodeZigZag(std::numeric_limits<uint64_t>::max() - 1, 1));
+ EXPECT_EQ(std::numeric_limits<int64_t>::max(),
+ DecodeZigZag(std::numeric_limits<uint64_t>::max() - 2, 1));
+}
+
+TEST(ZigZagCoding, Encode2) {
+ EXPECT_EQ(0u, EncodeZigZag(0, 2));
+ EXPECT_EQ(1u, EncodeZigZag(1, 2));
+ EXPECT_EQ(2u, EncodeZigZag(2, 2));
+ EXPECT_EQ(3u, EncodeZigZag(3, 2));
+ EXPECT_EQ(4u, EncodeZigZag(-1, 2));
+ EXPECT_EQ(5u, EncodeZigZag(-2, 2));
+ EXPECT_EQ(6u, EncodeZigZag(-3, 2));
+ EXPECT_EQ(7u, EncodeZigZag(-4, 2));
+ EXPECT_EQ(8u, EncodeZigZag(4, 2));
+ EXPECT_EQ(9u, EncodeZigZag(5, 2));
+ EXPECT_EQ(10u, EncodeZigZag(6, 2));
+ EXPECT_EQ(11u, EncodeZigZag(7, 2));
+ EXPECT_EQ(12u, EncodeZigZag(-5, 2));
+ EXPECT_EQ(13u, EncodeZigZag(-6, 2));
+ EXPECT_EQ(14u, EncodeZigZag(-7, 2));
+ EXPECT_EQ(15u, EncodeZigZag(-8, 2));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max() - 4,
+ EncodeZigZag(std::numeric_limits<int64_t>::max(), 2));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max() - 3,
+ EncodeZigZag(std::numeric_limits<int64_t>::min() + 3, 2));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max() - 2,
+ EncodeZigZag(std::numeric_limits<int64_t>::min() + 2, 2));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max() - 1,
+ EncodeZigZag(std::numeric_limits<int64_t>::min() + 1, 2));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max(),
+ EncodeZigZag(std::numeric_limits<int64_t>::min(), 2));
+}
+
+TEST(ZigZagCoding, Decode2) {
+ EXPECT_EQ(0, DecodeZigZag(0, 2));
+ EXPECT_EQ(1, DecodeZigZag(1, 2));
+ EXPECT_EQ(2, DecodeZigZag(2, 2));
+ EXPECT_EQ(3, DecodeZigZag(3, 2));
+ EXPECT_EQ(-1, DecodeZigZag(4, 2));
+ EXPECT_EQ(-2, DecodeZigZag(5, 2));
+ EXPECT_EQ(-3, DecodeZigZag(6, 2));
+ EXPECT_EQ(-4, DecodeZigZag(7, 2));
+ EXPECT_EQ(4, DecodeZigZag(8, 2));
+ EXPECT_EQ(5, DecodeZigZag(9, 2));
+ EXPECT_EQ(6, DecodeZigZag(10, 2));
+ EXPECT_EQ(7, DecodeZigZag(11, 2));
+ EXPECT_EQ(-5, DecodeZigZag(12, 2));
+ EXPECT_EQ(-6, DecodeZigZag(13, 2));
+ EXPECT_EQ(-7, DecodeZigZag(14, 2));
+ EXPECT_EQ(-8, DecodeZigZag(15, 2));
+ EXPECT_EQ(std::numeric_limits<int64_t>::min(),
+ DecodeZigZag(std::numeric_limits<uint64_t>::max(), 2));
+ EXPECT_EQ(std::numeric_limits<int64_t>::min() + 1,
+ DecodeZigZag(std::numeric_limits<uint64_t>::max() - 1, 2));
+ EXPECT_EQ(std::numeric_limits<int64_t>::min() + 2,
+ DecodeZigZag(std::numeric_limits<uint64_t>::max() - 2, 2));
+ EXPECT_EQ(std::numeric_limits<int64_t>::min() + 3,
+ DecodeZigZag(std::numeric_limits<uint64_t>::max() - 3, 2));
+ EXPECT_EQ(std::numeric_limits<int64_t>::max(),
+ DecodeZigZag(std::numeric_limits<uint64_t>::max() - 4, 2));
+}
+
+TEST(ZigZagCoding, Encode63) {
+ EXPECT_EQ(0u, EncodeZigZag(0, 63));
+
+ for (int64_t i = 0; i < 0xFFFFFFFF; i += 1234567) {
+ const int64_t positive_val = GetLowerBits(i * i * i + i * i, 63) | 1UL;
+ ASSERT_EQ(static_cast<uint64_t>(positive_val),
+ EncodeZigZag(positive_val, 63));
+ ASSERT_EQ((1ULL << 63) - 1 + positive_val, EncodeZigZag(-positive_val, 63));
+ }
+
+ EXPECT_EQ((1ULL << 63) - 1,
+ EncodeZigZag(std::numeric_limits<int64_t>::max(), 63));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max() - 1,
+ EncodeZigZag(std::numeric_limits<int64_t>::min() + 1, 63));
+ EXPECT_EQ(std::numeric_limits<uint64_t>::max(),
+ EncodeZigZag(std::numeric_limits<int64_t>::min(), 63));
+}
+
+TEST(BufToStream, UInt8_Empty) {
+ const std::string expected_bits = "";
+ std::vector<uint8_t> buffer = StreamToBuffer<uint8_t>(expected_bits);
+ EXPECT_TRUE(buffer.empty());
+ const std::string result_bits = BufferToStream(buffer);
+ EXPECT_EQ(expected_bits, result_bits);
+}
+
+TEST(BufToStream, UInt8_OneWord) {
+ const std::string expected_bits = "00101100";
+ std::vector<uint8_t> buffer = StreamToBuffer<uint8_t>(expected_bits);
+ EXPECT_EQ(
+ std::vector<uint8_t>(
+ {static_cast<uint8_t>(StreamToBitset<8>(expected_bits).to_ulong())}),
+ buffer);
+ const std::string result_bits = BufferToStream(buffer);
+ EXPECT_EQ(expected_bits, result_bits);
+}
+
+TEST(BufToStream, UInt8_MultipleWords) {
+ const std::string expected_bits = "00100010""01101010""01111101""00100010";
+ std::vector<uint8_t> buffer = StreamToBuffer<uint8_t>(expected_bits);
+ EXPECT_EQ(
+ std::vector<uint8_t>({
+ static_cast<uint8_t>(StreamToBitset<8>("00100010").to_ulong()),
+ static_cast<uint8_t>(StreamToBitset<8>("01101010").to_ulong()),
+ static_cast<uint8_t>(StreamToBitset<8>("01111101").to_ulong()),
+ static_cast<uint8_t>(StreamToBitset<8>("00100010").to_ulong()),
+ }), buffer);
+ const std::string result_bits = BufferToStream(buffer);
+ EXPECT_EQ(expected_bits, result_bits);
+}
+
+TEST(BufToStream, UInt64_Empty) {
+ const std::string expected_bits = "";
+ std::vector<uint64_t> buffer = StreamToBuffer<uint64_t>(expected_bits);
+ EXPECT_TRUE(buffer.empty());
+ const std::string result_bits = BufferToStream(buffer);
+ EXPECT_EQ(expected_bits, result_bits);
+}
+
+TEST(BufToStream, UInt64_OneWord) {
+ const std::string expected_bits =
+ "0001000111101110011001101010101000100010110011000100010010001000";
+ std::vector<uint64_t> buffer = StreamToBuffer<uint64_t>(expected_bits);
+ ASSERT_EQ(1u, buffer.size());
+ EXPECT_EQ(0x1122334455667788u, buffer[0]);
+ const std::string result_bits = BufferToStream(buffer);
+ EXPECT_EQ(expected_bits, result_bits);
+}
+
+TEST(BufToStream, UInt64_Unaligned) {
+ const std::string expected_bits =
+ "0010001001101010011111010010001001001010000111110010010010010101"
+ "0010001001101010011111111111111111111111";
+ std::vector<uint64_t> buffer = StreamToBuffer<uint64_t>(expected_bits);
+ EXPECT_EQ(std::vector<uint64_t>({
+ StreamToBits(expected_bits.substr(0, 64)),
+ StreamToBits(expected_bits.substr(64, 64)),
+ }), buffer);
+ const std::string result_bits = BufferToStream(buffer);
+ EXPECT_EQ(PadToWord<64>(expected_bits), result_bits);
+}
+
+TEST(BufToStream, UInt64_MultipleWords) {
+ const std::string expected_bits =
+ "0010001001101010011111010010001001001010000111110010010010010101"
+ "0010001001101010011111111111111111111111000111110010010010010111"
+ "0000000000000000000000000000000000000000000000000010010011111111";
+ std::vector<uint64_t> buffer = StreamToBuffer<uint64_t>(expected_bits);
+ EXPECT_EQ(std::vector<uint64_t>({
+ StreamToBits(expected_bits.substr(0, 64)),
+ StreamToBits(expected_bits.substr(64, 64)),
+ StreamToBits(expected_bits.substr(128, 64)),
+ }), buffer);
+ const std::string result_bits = BufferToStream(buffer);
+ EXPECT_EQ(expected_bits, result_bits);
+}
+
+TEST(PadToWord, Test) {
+ EXPECT_EQ("10100000", PadToWord<8>("101"));
+ EXPECT_EQ("10100000""00000000", PadToWord<16>("101"));
+ EXPECT_EQ("10100000""00000000""00000000""00000000",
+ PadToWord<32>("101"));
+ EXPECT_EQ("10100000""00000000""00000000""00000000"
+ "00000000""00000000""00000000""00000000",
+ PadToWord<64>("101"));
+}
+
+TEST(BitWriterStringStream, Empty) {
+ BitWriterStringStream writer;
+ EXPECT_EQ(0u, writer.GetNumBits());
+ EXPECT_EQ(0u, writer.GetDataSizeBytes());
+ EXPECT_EQ("", writer.GetStreamRaw());
+}
+
+TEST(BitWriterStringStream, WriteStream) {
+ BitWriterStringStream writer;
+ const std::string bits1 = "1011111111111111111";
+ writer.WriteStream(bits1);
+ EXPECT_EQ(19u, writer.GetNumBits());
+ EXPECT_EQ(3u, writer.GetDataSizeBytes());
+ EXPECT_EQ(bits1, writer.GetStreamRaw());
+
+ const std::string bits2 = "10100001010101010000111111111111111111111111111";
+ writer.WriteStream(bits2);
+ EXPECT_EQ(66u, writer.GetNumBits());
+ EXPECT_EQ(9u, writer.GetDataSizeBytes());
+ EXPECT_EQ(bits1 + bits2, writer.GetStreamRaw());
+}
+
+TEST(BitWriterStringStream, WriteBitSet) {
+ BitWriterStringStream writer;
+ const std::string bits1 = "10101";
+ writer.WriteBitset(StreamToBitset<16>(bits1));
+ EXPECT_EQ(16u, writer.GetNumBits());
+ EXPECT_EQ(2u, writer.GetDataSizeBytes());
+ EXPECT_EQ(PadToWord<16>(bits1), writer.GetStreamRaw());
+}
+
+TEST(BitWriterStringStream, WriteBits) {
+ BitWriterStringStream writer;
+ const uint64_t bits1 = 0x1 | 0x2 | 0x10;
+ writer.WriteBits(bits1, 5);
+ EXPECT_EQ(5u, writer.GetNumBits());
+ EXPECT_EQ(1u, writer.GetDataSizeBytes());
+ EXPECT_EQ("11001", writer.GetStreamRaw());
+}
+
+TEST(BitWriterStringStream, WriteMultiple) {
+ BitWriterStringStream writer;
+
+ std::string expected_result;
+ const std::string bits1 = "101001111111001100010000001110001111111100";
+ writer.WriteStream(bits1);
+
+ const std::string bits2 = "10100011000010010101";
+ writer.WriteBitset(StreamToBitset<20>(bits2));
+
+ const uint64_t val = 0x1 | 0x2 | 0x10;
+ const std::string bits3 = BitsToStream(val, 8);
+ writer.WriteBits(val, 8);
+
+ const std::string expected = bits1 + bits2 + bits3;
+
+ EXPECT_EQ(expected.length(), writer.GetNumBits());
+ EXPECT_EQ(9u, writer.GetDataSizeBytes());
+ EXPECT_EQ(expected, writer.GetStreamRaw());
+
+ EXPECT_EQ(PadToWord<8>(expected), BufferToStream(writer.GetDataCopy()));
+}
+
+TEST(BitWriterWord64, Empty) {
+ BitWriterWord64 writer;
+ EXPECT_EQ(0u, writer.GetNumBits());
+ EXPECT_EQ(0u, writer.GetDataSizeBytes());
+ EXPECT_EQ("", writer.GetStreamPadded64());
+}
+
+TEST(BitWriterWord64, WriteStream) {
+ BitWriterWord64 writer;
+ std::string expected;
+
+ {
+ const std::string bits = "101";
+ expected += bits;
+ writer.WriteStream(bits);
+ EXPECT_EQ(expected.length(), writer.GetNumBits());
+ EXPECT_EQ(1u, writer.GetDataSizeBytes());
+ EXPECT_EQ(PadToWord<64>(expected), writer.GetStreamPadded64());
+ }
+
+ {
+ const std::string bits = "10000111111111110000000";
+ expected += bits;
+ writer.WriteStream(bits);
+ EXPECT_EQ(expected.length(), writer.GetNumBits());
+ EXPECT_EQ(PadToWord<64>(expected), writer.GetStreamPadded64());
+ }
+
+ {
+ const std::string bits = "101001111111111100000111111111111100";
+ expected += bits;
+ writer.WriteStream(bits);
+ EXPECT_EQ(expected.length(), writer.GetNumBits());
+ EXPECT_EQ(PadToWord<64>(expected), writer.GetStreamPadded64());
+ }
+}
+
+TEST(BitWriterWord64, WriteBitset) {
+ BitWriterWord64 writer;
+ const std::string bits1 = "10101";
+ writer.WriteBitset(StreamToBitset<16>(bits1), 12);
+ EXPECT_EQ(12u, writer.GetNumBits());
+ EXPECT_EQ(2u, writer.GetDataSizeBytes());
+ EXPECT_EQ(PadToWord<64>(bits1), writer.GetStreamPadded64());
+}
+
+TEST(BitWriterWord64, WriteBits) {
+ BitWriterWord64 writer;
+ const uint64_t bits1 = 0x1 | 0x2 | 0x10;
+ writer.WriteBits(bits1, 5);
+ writer.WriteBits(bits1, 5);
+ writer.WriteBits(bits1, 5);
+ EXPECT_EQ(15u, writer.GetNumBits());
+ EXPECT_EQ(2u, writer.GetDataSizeBytes());
+ EXPECT_EQ(PadToWord<64>("110011100111001"), writer.GetStreamPadded64());
+}
+
+TEST(BitWriterWord64, ComparisonTestWriteLotsOfBits) {
+ BitWriterStringStream writer1;
+ BitWriterWord64 writer2(16384);
+
+ for (uint64_t i = 0; i < 65000; i += 25) {
+ writer1.WriteBits(i, 16);
+ writer2.WriteBits(i, 16);
+ ASSERT_EQ(writer1.GetNumBits(), writer2.GetNumBits());
+ }
+
+ EXPECT_EQ(PadToWord<64>(writer1.GetStreamRaw()),
+ writer2.GetStreamPadded64());
+}
+
+TEST(BitWriterWord64, ComparisonTestWriteLotsOfStreams) {
+ BitWriterStringStream writer1;
+ BitWriterWord64 writer2(16384);
+
+ for (int i = 0; i < 1000; ++i) {
+ std::string bits = "1111100000";
+ if (i % 2)
+ bits += "101010";
+ if (i % 3)
+ bits += "1110100";
+ if (i % 5)
+ bits += "1110100111111111111";
+ writer1.WriteStream(bits);
+ writer2.WriteStream(bits);
+ ASSERT_EQ(writer1.GetNumBits(), writer2.GetNumBits());
+ }
+
+ EXPECT_EQ(PadToWord<64>(writer1.GetStreamRaw()),
+ writer2.GetStreamPadded64());
+}
+
+TEST(BitWriterWord64, ComparisonTestWriteLotsOfBitsets) {
+ BitWriterStringStream writer1;
+ BitWriterWord64 writer2(16384);
+
+ for (uint64_t i = 0; i < 65000; i += 25) {
+ std::bitset<16> bits1(i);
+ std::bitset<24> bits2(i);
+ writer1.WriteBitset(bits1);
+ writer1.WriteBitset(bits2);
+ writer2.WriteBitset(bits1);
+ writer2.WriteBitset(bits2);
+ ASSERT_EQ(writer1.GetNumBits(), writer2.GetNumBits());
+ }
+
+ EXPECT_EQ(PadToWord<64>(writer1.GetStreamRaw()),
+ writer2.GetStreamPadded64());
+}
+
+TEST(GetLowerBits, Test) {
+ EXPECT_EQ(0u, GetLowerBits<uint8_t>(255, 0));
+ EXPECT_EQ(1u, GetLowerBits<uint8_t>(255, 1));
+ EXPECT_EQ(3u, GetLowerBits<uint8_t>(255, 2));
+ EXPECT_EQ(7u, GetLowerBits<uint8_t>(255, 3));
+ EXPECT_EQ(15u, GetLowerBits<uint8_t>(255, 4));
+ EXPECT_EQ(31u, GetLowerBits<uint8_t>(255, 5));
+ EXPECT_EQ(63u, GetLowerBits<uint8_t>(255, 6));
+ EXPECT_EQ(127u, GetLowerBits<uint8_t>(255, 7));
+ EXPECT_EQ(255u, GetLowerBits<uint8_t>(255, 8));
+ EXPECT_EQ(0xFFu, GetLowerBits<uint32_t>(0xFFFFFFFF, 8));
+ EXPECT_EQ(0xFFFFu, GetLowerBits<uint32_t>(0xFFFFFFFF, 16));
+ EXPECT_EQ(0xFFFFFFu, GetLowerBits<uint32_t>(0xFFFFFFFF, 24));
+ EXPECT_EQ(0xFFFFFFu, GetLowerBits<uint64_t>(0xFFFFFFFFFFFF, 24));
+ EXPECT_EQ(0xFFFFFFFFFFFFFFFFu,
+ GetLowerBits<uint64_t>(0xFFFFFFFFFFFFFFFFu, 64));
+ EXPECT_EQ(StreamToBits("1010001110"),
+ GetLowerBits<uint64_t>(
+ StreamToBits("1010001110111101111111"), 10));
+}
+
+TEST(BitReaderFromString, FromU8) {
+ std::vector<uint8_t> buffer = {
+ 0xAA, 0xBB, 0xCC, 0xDD,
+ };
+
+ const std::string total_stream =
+ "01010101""11011101""00110011""10111011";
+
+ BitReaderFromString reader(buffer);
+ EXPECT_EQ(PadToWord<64>(total_stream), reader.GetStreamPadded64());
+
+ uint64_t bits = 0;
+ EXPECT_EQ(2u, reader.ReadBits(&bits, 2));
+ EXPECT_EQ(PadToWord<64>("01"), BitsToStream(bits));
+ EXPECT_EQ(20u, reader.ReadBits(&bits, 20));
+ EXPECT_EQ(PadToWord<64>("01010111011101001100"), BitsToStream(bits));
+ EXPECT_EQ(20u, reader.ReadBits(&bits, 20));
+ EXPECT_EQ(PadToWord<64>("11101110110000000000"), BitsToStream(bits));
+ EXPECT_EQ(22u, reader.ReadBits(&bits, 30));
+ EXPECT_EQ(PadToWord<64>("0000000000000000000000"), BitsToStream(bits));
+ EXPECT_TRUE(reader.ReachedEnd());
+}
+
+TEST(BitReaderFromString, FromU64) {
+ std::vector<uint64_t> buffer = {
+ 0xAAAAAAAAAAAAAAAA,
+ 0xBBBBBBBBBBBBBBBB,
+ 0xCCCCCCCCCCCCCCCC,
+ 0xDDDDDDDDDDDDDDDD,
+ };
+
+ const std::string total_stream =
+ "0101010101010101010101010101010101010101010101010101010101010101"
+ "1101110111011101110111011101110111011101110111011101110111011101"
+ "0011001100110011001100110011001100110011001100110011001100110011"
+ "1011101110111011101110111011101110111011101110111011101110111011";
+
+ BitReaderFromString reader(buffer);
+ EXPECT_EQ(total_stream, reader.GetStreamPadded64());
+
+ uint64_t bits = 0;
+ size_t pos = 0;
+ size_t to_read = 5;
+ while (reader.ReadBits(&bits, to_read) > 0) {
+ EXPECT_EQ(BitsToStream(bits),
+ PadToWord<64>(total_stream.substr(pos, to_read)));
+ pos += to_read;
+ to_read = (to_read + 35) % 64 + 1;
+ }
+ EXPECT_TRUE(reader.ReachedEnd());
+}
+
+TEST(BitReaderWord64, ReadBitsSingleByte) {
+ BitReaderWord64 reader(std::vector<uint8_t>({uint8_t(0xF0)}));
+ EXPECT_FALSE(reader.ReachedEnd());
+
+ uint64_t bits = 0;
+ EXPECT_EQ(1u, reader.ReadBits(&bits, 1));
+ EXPECT_EQ(0u, bits);
+ EXPECT_EQ(2u, reader.ReadBits(&bits, 2));
+ EXPECT_EQ(0u, bits);
+ EXPECT_EQ(2u, reader.ReadBits(&bits, 2));
+ EXPECT_EQ(2u, bits);
+ EXPECT_EQ(2u, reader.ReadBits(&bits, 2));
+ EXPECT_EQ(3u, bits);
+ EXPECT_FALSE(reader.OnlyZeroesLeft());
+ EXPECT_FALSE(reader.ReachedEnd());
+ EXPECT_EQ(2u, reader.ReadBits(&bits, 2));
+ EXPECT_EQ(1u, bits);
+ EXPECT_TRUE(reader.OnlyZeroesLeft());
+ EXPECT_FALSE(reader.ReachedEnd());
+ EXPECT_EQ(55u, reader.ReadBits(&bits, 64));
+ EXPECT_EQ(0u, bits);
+ EXPECT_TRUE(reader.ReachedEnd());
+}
+
+TEST(BitReaderWord64, ReadBitsetSingleByte) {
+ BitReaderWord64 reader(std::vector<uint8_t>({uint8_t(0xCC)}));
+ std::bitset<4> bits;
+ EXPECT_EQ(2u, reader.ReadBitset(&bits, 2));
+ EXPECT_EQ(0u, bits.to_ullong());
+ EXPECT_EQ(2u, reader.ReadBitset(&bits, 2));
+ EXPECT_EQ(3u, bits.to_ullong());
+ EXPECT_FALSE(reader.OnlyZeroesLeft());
+ EXPECT_EQ(4u, reader.ReadBitset(&bits, 4));
+ EXPECT_EQ(12u, bits.to_ullong());
+ EXPECT_TRUE(reader.OnlyZeroesLeft());
+}
+
+TEST(BitReaderWord64, ReadStreamSingleByte) {
+ BitReaderWord64 reader(std::vector<uint8_t>({uint8_t(0xAA)}));
+ EXPECT_EQ("", reader.ReadStream(0));
+ EXPECT_EQ("0", reader.ReadStream(1));
+ EXPECT_EQ("101", reader.ReadStream(3));
+ EXPECT_EQ("01010000", reader.ReadStream(8));
+ EXPECT_TRUE(reader.OnlyZeroesLeft());
+ EXPECT_EQ("0000000000000000000000000000000000000000000000000000",
+ reader.ReadStream(64));
+ EXPECT_TRUE(reader.ReachedEnd());
+}
+
+TEST(BitReaderWord64, ReadStreamEmpty) {
+ std::vector<uint64_t> buffer;
+ BitReaderWord64 reader(std::move(buffer));
+ EXPECT_TRUE(reader.OnlyZeroesLeft());
+ EXPECT_TRUE(reader.ReachedEnd());
+ EXPECT_EQ("", reader.ReadStream(10));
+ EXPECT_TRUE(reader.ReachedEnd());
+}
+
+TEST(BitReaderWord64, ReadBitsTwoWords) {
+ std::vector<uint64_t> buffer = {
+ 0x0000000000000001,
+ 0x0000000000FFFFFF
+ };
+
+ BitReaderWord64 reader(std::move(buffer));
+
+ uint64_t bits = 0;
+ EXPECT_EQ(1u, reader.ReadBits(&bits, 1));
+ EXPECT_EQ(1u, bits);
+ EXPECT_EQ(62u, reader.ReadBits(&bits, 62));
+ EXPECT_EQ(0u, bits);
+ EXPECT_EQ(2u, reader.ReadBits(&bits, 2));
+ EXPECT_EQ(2u, bits);
+ EXPECT_EQ(3u, reader.ReadBits(&bits, 3));
+ EXPECT_EQ(7u, bits);
+ EXPECT_FALSE(reader.OnlyZeroesLeft());
+ EXPECT_EQ(32u, reader.ReadBits(&bits, 32));
+ EXPECT_EQ(0xFFFFFu, bits);
+ EXPECT_TRUE(reader.OnlyZeroesLeft());
+ EXPECT_FALSE(reader.ReachedEnd());
+ EXPECT_EQ(28u, reader.ReadBits(&bits, 32));
+ EXPECT_EQ(0u, bits);
+ EXPECT_TRUE(reader.ReachedEnd());
+}
+
+TEST(BitReaderWord64, FromU8) {
+ std::vector<uint8_t> buffer = {
+ 0xAA, 0xBB, 0xCC, 0xDD,
+ };
+
+ BitReaderWord64 reader(std::move(buffer));
+
+ uint64_t bits = 0;
+ EXPECT_EQ(2u, reader.ReadBits(&bits, 2));
+ EXPECT_EQ(PadToWord<64>("01"), BitsToStream(bits));
+ EXPECT_EQ(20u, reader.ReadBits(&bits, 20));
+ EXPECT_EQ(PadToWord<64>("01010111011101001100"), BitsToStream(bits));
+ EXPECT_EQ(20u, reader.ReadBits(&bits, 20));
+ EXPECT_EQ(PadToWord<64>("11101110110000000000"), BitsToStream(bits));
+ EXPECT_EQ(22u, reader.ReadBits(&bits, 30));
+ EXPECT_EQ(PadToWord<64>("0000000000000000000000"), BitsToStream(bits));
+ EXPECT_TRUE(reader.ReachedEnd());
+}
+
+TEST(BitReaderWord64, FromU64) {
+ std::vector<uint64_t> buffer = {
+ 0xAAAAAAAAAAAAAAAA,
+ 0xBBBBBBBBBBBBBBBB,
+ 0xCCCCCCCCCCCCCCCC,
+ 0xDDDDDDDDDDDDDDDD,
+ };
+
+ const std::string total_stream =
+ "0101010101010101010101010101010101010101010101010101010101010101"
+ "1101110111011101110111011101110111011101110111011101110111011101"
+ "0011001100110011001100110011001100110011001100110011001100110011"
+ "1011101110111011101110111011101110111011101110111011101110111011";
+
+ BitReaderWord64 reader(std::move(buffer));
+
+ uint64_t bits = 0;
+ size_t pos = 0;
+ size_t to_read = 5;
+ while (reader.ReadBits(&bits, to_read) > 0) {
+ EXPECT_EQ(BitsToStream(bits),
+ PadToWord<64>(total_stream.substr(pos, to_read)));
+ pos += to_read;
+ to_read = (to_read + 35) % 64 + 1;
+ }
+ EXPECT_TRUE(reader.ReachedEnd());
+}
+
+TEST(BitReaderWord64, ComparisonLotsOfU8) {
+ std::vector<uint8_t> buffer;
+ for(uint32_t i = 0; i < 10003; ++i) {
+ buffer.push_back(static_cast<uint8_t>(i % 255));
+ }
+
+ BitReaderFromString reader1(buffer);
+ BitReaderWord64 reader2(std::move(buffer));
+
+ uint64_t bits1 = 0, bits2 = 0;
+ size_t to_read = 5;
+ while (reader1.ReadBits(&bits1, to_read) > 0) {
+ reader2.ReadBits(&bits2, to_read);
+ EXPECT_EQ(bits1, bits2);
+ to_read = (to_read + 35) % 64 + 1;
+ }
+
+ EXPECT_EQ(0u, reader2.ReadBits(&bits2, 1));
+}
+
+TEST(BitReaderWord64, ComparisonLotsOfU64) {
+ std::vector<uint64_t> buffer;
+ for(uint64_t i = 0; i < 1000; ++i) {
+ buffer.push_back(i);
+ }
+
+ BitReaderFromString reader1(buffer);
+ BitReaderWord64 reader2(std::move(buffer));
+
+ uint64_t bits1 = 0, bits2 = 0;
+ size_t to_read = 5;
+ while (reader1.ReadBits(&bits1, to_read) > 0) {
+ reader2.ReadBits(&bits2, to_read);
+ EXPECT_EQ(bits1, bits2);
+ to_read = (to_read + 35) % 64 + 1;
+ }
+
+ EXPECT_EQ(0u, reader2.ReadBits(&bits2, 1));
+}
+
+TEST(ReadWriteWord64, ReadWriteLotsOfBits) {
+ BitWriterWord64 writer(16384);
+ for (uint64_t i = 0; i < 65000; i += 25) {
+ const uint64_t num_bits = i % 64 + 1;
+ const uint64_t bits = i >> (64 - num_bits);
+ writer.WriteBits(bits, size_t(num_bits));
+ }
+
+ BitReaderWord64 reader(writer.GetDataCopy());
+ for (uint64_t i = 0; i < 65000; i += 25) {
+ const uint64_t num_bits = i % 64 + 1;
+ const uint64_t expected_bits = i >> (64 - num_bits);
+ uint64_t bits = 0;
+ reader.ReadBits(&bits, size_t(num_bits));
+ EXPECT_EQ(expected_bits, bits);
+ }
+
+ EXPECT_TRUE(reader.OnlyZeroesLeft());
+}
+
+TEST(VariableWidthWrite, Write0U) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthU64(0, 2);
+ EXPECT_EQ("000", writer.GetStreamRaw ());
+ writer.WriteVariableWidthU32(0, 2);
+ EXPECT_EQ("000""000", writer.GetStreamRaw());
+ writer.WriteVariableWidthU16(0, 2);
+ EXPECT_EQ("000""000""000", writer.GetStreamRaw());
+ writer.WriteVariableWidthU8(0, 2);
+ EXPECT_EQ("000""000""000""000", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthWrite, Write0S) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthS64(0, 2, 0);
+ EXPECT_EQ("000", writer.GetStreamRaw ());
+ writer.WriteVariableWidthS32(0, 2, 0);
+ EXPECT_EQ("000""000", writer.GetStreamRaw());
+ writer.WriteVariableWidthS16(0, 2, 0);
+ EXPECT_EQ("000""000""000", writer.GetStreamRaw());
+ writer.WriteVariableWidthS8(0, 2, 0);
+ EXPECT_EQ("000""000""000""000", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthWrite, WriteSmallUnsigned) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthU64(1, 2);
+ EXPECT_EQ("100", writer.GetStreamRaw ());
+ writer.WriteVariableWidthU32(2, 2);
+ EXPECT_EQ("100""010", writer.GetStreamRaw());
+ writer.WriteVariableWidthU16(3, 2);
+ EXPECT_EQ("100""010""110", writer.GetStreamRaw());
+ writer.WriteVariableWidthU8(4, 2);
+ EXPECT_EQ("100""010""110""001100", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthWrite, WriteSmallSigned) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthS64(1, 2, 0);
+ EXPECT_EQ("010", writer.GetStreamRaw ());
+ writer.WriteVariableWidthS64(-1, 2, 0);
+ EXPECT_EQ("010""100", writer.GetStreamRaw());
+ writer.WriteVariableWidthS16(3, 2, 0);
+ EXPECT_EQ("010""100""011100", writer.GetStreamRaw());
+ writer.WriteVariableWidthS8(-4, 2, 0);
+ EXPECT_EQ("010""100""011100""111100", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthWrite, U64Val127ChunkLength7) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthU64(127, 7);
+ EXPECT_EQ("1111111""0", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthWrite, U32Val255ChunkLength7) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthU32(255, 7);
+ EXPECT_EQ("1111111""1""1000000""0", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthWrite, U16Val2ChunkLength4) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthU16(2, 4);
+ EXPECT_EQ("0100""0", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthWrite, U8Val128ChunkLength7) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthU8(128, 7);
+ EXPECT_EQ("0000000""1""1", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthWrite, U64ValAAAAChunkLength2) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthU64(0xAAAA, 2);
+ EXPECT_EQ("01""1""01""1""01""1""01""1"
+ "01""1""01""1""01""1""01""0", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthWrite, S8ValM128ChunkLength7) {
+ BitWriterStringStream writer;
+ writer.WriteVariableWidthS8(-128, 7, 0);
+ EXPECT_EQ("1111111""1""1", writer.GetStreamRaw());
+}
+
+TEST(VariableWidthRead, U64Val127ChunkLength7) {
+ BitReaderFromString reader("1111111""0");
+ uint64_t val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU64(&val, 7));
+ EXPECT_EQ(127u, val);
+}
+
+TEST(VariableWidthRead, U32Val255ChunkLength7) {
+ BitReaderFromString reader("1111111""1""1000000""0");
+ uint32_t val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU32(&val, 7));
+ EXPECT_EQ(255u, val);
+}
+
+TEST(VariableWidthRead, U16Val2ChunkLength4) {
+ BitReaderFromString reader("0100""0");
+ uint16_t val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU16(&val, 4));
+ EXPECT_EQ(2u, val);
+}
+
+TEST(VariableWidthRead, U8Val128ChunkLength7) {
+ BitReaderFromString reader("0000000""1""1");
+ uint8_t val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU8(&val, 7));
+ EXPECT_EQ(128u, val);
+}
+
+TEST(VariableWidthRead, U64ValAAAAChunkLength2) {
+ BitReaderFromString reader("01""1""01""1""01""1""01""1"
+ "01""1""01""1""01""1""01""0");
+ uint64_t val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU64(&val, 2));
+ EXPECT_EQ(0xAAAAu, val);
+}
+
+TEST(VariableWidthRead, S8ValM128ChunkLength7) {
+ BitReaderFromString reader("1111111""1""1");
+ int8_t val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthS8(&val, 7, 0));
+ EXPECT_EQ(-128, val);
+}
+
+TEST(VariableWidthRead, FailTooShort) {
+ BitReaderFromString reader("00000001100000");
+ uint64_t val = 0;
+ ASSERT_FALSE(reader.ReadVariableWidthU64(&val, 7));
+}
+
+TEST(VariableWidthWriteRead, SingleWriteReadU64) {
+ for (uint64_t i = 0; i < 1000000; i += 1234) {
+ const uint64_t val = i * i * i;
+ const size_t chunk_length = size_t(i % 16 + 1);
+
+ BitWriterWord64 writer;
+ writer.WriteVariableWidthU64(val, chunk_length);
+
+ BitReaderWord64 reader(writer.GetDataCopy());
+ uint64_t read_val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU64(&read_val, chunk_length));
+
+ ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length;
+ }
+}
+
+TEST(VariableWidthWriteRead, SingleWriteReadS64) {
+ for (int64_t i = 0; i < 1000000; i += 4321) {
+ const int64_t val = i * i * (i % 2 ? -i : i);
+ const size_t chunk_length = size_t(i % 16 + 1);
+ const size_t zigzag_exponent = size_t(i % 13);
+
+ BitWriterWord64 writer;
+ writer.WriteVariableWidthS64(val, chunk_length, zigzag_exponent);
+
+ BitReaderWord64 reader(writer.GetDataCopy());
+ int64_t read_val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthS64(&read_val, chunk_length,
+ zigzag_exponent));
+
+ ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length;
+ }
+}
+
+TEST(VariableWidthWriteRead, SingleWriteReadU32) {
+ for (uint32_t i = 0; i < 100000; i += 123) {
+ const uint32_t val = i * i;
+ const size_t chunk_length = i % 16 + 1;
+
+ BitWriterWord64 writer;
+ writer.WriteVariableWidthU32(val, chunk_length);
+
+ BitReaderWord64 reader(writer.GetDataCopy());
+ uint32_t read_val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU32(&read_val, chunk_length));
+
+ ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length;
+ }
+}
+
+TEST(VariableWidthWriteRead, SingleWriteReadS32) {
+ for (int32_t i = 0; i < 100000; i += 123) {
+ const int32_t val = i * (i % 2 ? -i : i);
+ const size_t chunk_length = i % 16 + 1;
+ const size_t zigzag_exponent = i % 11;
+
+ BitWriterWord64 writer;
+ writer.WriteVariableWidthS32(val, chunk_length, zigzag_exponent);
+
+ BitReaderWord64 reader(writer.GetDataCopy());
+ int32_t read_val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthS32(
+ &read_val, chunk_length, zigzag_exponent));
+
+ ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length;
+ }
+}
+
+TEST(VariableWidthWriteRead, SingleWriteReadU16) {
+ for (int i = 0; i < 65536; i += 123) {
+ const uint16_t val = static_cast<int16_t>(i);
+ const size_t chunk_length = val % 10 + 1;
+
+ BitWriterWord64 writer;
+ writer.WriteVariableWidthU16(val, chunk_length);
+
+ BitReaderWord64 reader(writer.GetDataCopy());
+ uint16_t read_val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU16(&read_val, chunk_length));
+
+ ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length;
+ }
+}
+
+TEST(VariableWidthWriteRead, SingleWriteReadS16) {
+ for (int i = -32768; i < 32768; i += 123) {
+ const int16_t val = static_cast<int16_t>(i);
+ const size_t chunk_length = std::abs(i) % 10 + 1;
+ const size_t zigzag_exponent = std::abs(i) % 7;
+
+ BitWriterWord64 writer;
+ writer.WriteVariableWidthS16(val, chunk_length, zigzag_exponent);
+
+ BitReaderWord64 reader(writer.GetDataCopy());
+ int16_t read_val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthS16(&read_val, chunk_length,
+ zigzag_exponent));
+
+ ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length;
+ }
+}
+
+TEST(VariableWidthWriteRead, SingleWriteReadU8) {
+ for (int i = 0; i < 256; ++i) {
+ const uint8_t val = static_cast<uint8_t>(i);
+ const size_t chunk_length = val % 5 + 1;
+
+ BitWriterWord64 writer;
+ writer.WriteVariableWidthU8(val, chunk_length);
+
+ BitReaderWord64 reader(writer.GetDataCopy());
+ uint8_t read_val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU8(&read_val, chunk_length));
+
+ ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length;
+ }
+}
+
+TEST(VariableWidthWriteRead, SingleWriteReadS8) {
+ for (int i = -128; i < 128; ++i) {
+ const int8_t val = static_cast<int8_t>(i);
+ const size_t chunk_length = std::abs(i) % 5 + 1;
+ const size_t zigzag_exponent = std::abs(i) % 3;
+
+ BitWriterWord64 writer;
+ writer.WriteVariableWidthS8(val, chunk_length, zigzag_exponent);
+
+ BitReaderWord64 reader(writer.GetDataCopy());
+ int8_t read_val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthS8(&read_val, chunk_length,
+ zigzag_exponent));
+
+ ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length;
+ }
+}
+
+TEST(VariableWidthWriteRead, SmallNumbersChunkLength4) {
+ const std::vector<uint64_t> expected_values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+
+ BitWriterWord64 writer;
+ for (uint64_t val : expected_values) {
+ writer.WriteVariableWidthU64(val, 4);
+ }
+
+ EXPECT_EQ(50u, writer.GetNumBits());
+
+ std::vector<uint64_t> actual_values;
+ BitReaderWord64 reader(writer.GetDataCopy());
+ while(!reader.OnlyZeroesLeft()) {
+ uint64_t val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU64(&val, 4));
+ actual_values.push_back(val);
+ }
+
+ EXPECT_EQ(expected_values, actual_values);
+}
+
+TEST(VariableWidthWriteRead, VariedNumbersChunkLength8) {
+ const std::vector<uint64_t> expected_values = {1000, 0, 255, 4294967296};
+ const size_t kExpectedNumBits = 9 * (2 + 1 + 1 + 5);
+
+ BitWriterWord64 writer;
+ for (uint64_t val : expected_values) {
+ writer.WriteVariableWidthU64(val, 8);
+ }
+
+ EXPECT_EQ(kExpectedNumBits, writer.GetNumBits());
+
+ std::vector<uint64_t> actual_values;
+ BitReaderWord64 reader(writer.GetDataCopy());
+ while (!reader.OnlyZeroesLeft()) {
+ uint64_t val = 0;
+ ASSERT_TRUE(reader.ReadVariableWidthU64(&val, 8));
+ actual_values.push_back(val);
+ }
+
+ EXPECT_EQ(expected_values, actual_values);
+}
+
+} // anonymous namespace
diff --git a/test/enum_string_mapping_test.cpp b/test/enum_string_mapping_test.cpp
index be1b77da..7060cb99 100644
--- a/test/enum_string_mapping_test.cpp
+++ b/test/enum_string_mapping_test.cpp
@@ -220,10 +220,14 @@ INSTANTIATE_TEST_CASE_P(AllCapabilities, CapabilityTest,
"DrawParameters"},
{SpvCapabilitySubgroupVoteKHR,
"SubgroupVoteKHR"},
+ {SpvCapabilityStorageBuffer16BitAccess,
+ "StorageBuffer16BitAccess"},
{SpvCapabilityStorageUniformBufferBlock16,
- "StorageUniformBufferBlock16"},
+ "StorageBuffer16BitAccess"}, // Preferred name
+ {SpvCapabilityUniformAndStorageBuffer16BitAccess,
+ "UniformAndStorageBuffer16BitAccess"},
{SpvCapabilityStorageUniform16,
- "StorageUniform16"},
+ "UniformAndStorageBuffer16BitAccess"}, // Preferred name
{SpvCapabilityStoragePushConstant16,
"StoragePushConstant16"},
{SpvCapabilityStorageInputOutput16,
@@ -244,6 +248,6 @@ INSTANTIATE_TEST_CASE_P(AllCapabilities, CapabilityTest,
"ShaderStereoViewNV"},
{SpvCapabilityPerViewAttributesNV,
"PerViewAttributesNV"}
- })));
+ })), );
} // anonymous namespace
diff --git a/test/operand_capabilities_test.cpp b/test/operand_capabilities_test.cpp
index e2911364..e13b934d 100644
--- a/test/operand_capabilities_test.cpp
+++ b/test/operand_capabilities_test.cpp
@@ -52,7 +52,8 @@ TEST_P(EnumCapabilityTest, Sample) {
spvOperandTableValueLookup(operandTable, get<1>(GetParam()).type,
get<1>(GetParam()).value, &entry));
EXPECT_THAT(ElementsIn(entry->capabilities),
- Eq(ElementsIn(get<1>(GetParam()).expected_capabilities)));
+ Eq(ElementsIn(get<1>(GetParam()).expected_capabilities)))
+ << " capability value " << get<1>(GetParam()).value;
}
#define CASE0(TYPE, VALUE) \
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index 2f1f482a..fcaefe26 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -28,11 +28,26 @@ add_spvtools_unittest(TARGET pass_manager
LIBS SPIRV-Tools-opt
)
+add_spvtools_unittest(TARGET optimizer
+ SRCS optimizer_test.cpp
+ LIBS SPIRV-Tools-opt
+)
+
add_spvtools_unittest(TARGET pass_strip_debug_info
SRCS strip_debug_info_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
)
+add_spvtools_unittest(TARGET pass_compact_ids
+ SRCS compact_ids_test.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
+add_spvtools_unittest(TARGET pass_flatten_decoration
+ SRCS flatten_decoration_test.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
add_spvtools_unittest(TARGET pass_freeze_spec_const
SRCS freeze_spec_const_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
@@ -43,6 +58,16 @@ add_spvtools_unittest(TARGET pass_inline
LIBS SPIRV-Tools-opt
)
+add_spvtools_unittest(TARGET pass_local_single_block_elim
+ SRCS local_single_block_elim.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
+add_spvtools_unittest(TARGET pass_local_access_chain_convert
+ SRCS local_access_chain_convert_test.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
add_spvtools_unittest(TARGET pass_eliminate_dead_const
SRCS eliminate_dead_const_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
diff --git a/test/opt/compact_ids_test.cpp b/test/opt/compact_ids_test.cpp
new file mode 100644
index 00000000..7094609e
--- /dev/null
+++ b/test/opt/compact_ids_test.cpp
@@ -0,0 +1,90 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using CompactIdsTest = PassTest<::testing::Test>;
+
+TEST_F(CompactIdsTest, PassOff) {
+ const std::string before =
+R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%99 = OpTypeInt 32 0
+%10 = OpTypeVector %99 2
+%20 = OpConstant %99 2
+%30 = OpTypeArray %99 %20
+)";
+
+ const std::string after = before;
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::NullPass>(before, after, false, false);
+}
+
+TEST_F(CompactIdsTest, PassOn) {
+ const std::string before =
+R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %3 "simple_kernel"
+%99 = OpTypeInt 32 0
+%10 = OpTypeVector %99 2
+%20 = OpConstant %99 2
+%30 = OpTypeArray %99 %20
+%40 = OpTypeVoid
+%50 = OpTypeFunction %40
+ %3 = OpFunction %40 None %50
+%70 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %1 "simple_kernel"
+%2 = OpTypeInt 32 0
+%3 = OpTypeVector %2 2
+%4 = OpConstant %2 2
+%5 = OpTypeArray %2 %4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%1 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::CompactIdsPass>(before, after, false, false);
+}
+
+} // anonymous namespace
diff --git a/test/opt/flatten_decoration_test.cpp b/test/opt/flatten_decoration_test.cpp
new file mode 100644
index 00000000..8e6d979a
--- /dev/null
+++ b/test/opt/flatten_decoration_test.cpp
@@ -0,0 +1,234 @@
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+// Returns the initial part of the assembly text for a valid
+// SPIR-V module, including instructions prior to decorations.
+std::string PreambleAssembly() {
+ return
+ R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %hue %saturation %value
+OpName %main "main"
+OpName %void_fn "void_fn"
+OpName %hue "hue"
+OpName %saturation "saturation"
+OpName %value "value"
+OpName %entry "entry"
+OpName %Point "Point"
+OpName %Camera "Camera"
+)";
+}
+
+// Retuns types
+std::string TypesAndFunctionsAssembly() {
+ return
+ R"(%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%float = OpTypeFloat 32
+%Point = OpTypeStruct %float %float %float
+%Camera = OpTypeStruct %float %float
+%_ptr_Input_float = OpTypePointer Input %float
+%hue = OpVariable %_ptr_Input_float Input
+%saturation = OpVariable %_ptr_Input_float Input
+%value = OpVariable %_ptr_Input_float Input
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+}
+
+struct FlattenDecorationCase {
+ // Names and decorations before the pass.
+ std::string input;
+ // Names and decorations after the pass.
+ std::string expected;
+};
+
+using FlattenDecorationTest =
+ PassTest<::testing::TestWithParam<FlattenDecorationCase>>;
+
+TEST_P(FlattenDecorationTest, TransformsDecorations) {
+ const auto before =
+ PreambleAssembly() + GetParam().input + TypesAndFunctionsAssembly();
+ const auto after =
+ PreambleAssembly() + GetParam().expected + TypesAndFunctionsAssembly();
+
+ SinglePassRunAndCheck<opt::FlattenDecorationPass>(before, after, false, true);
+}
+
+INSTANTIATE_TEST_CASE_P(NoUses, FlattenDecorationTest,
+ ::testing::ValuesIn(std::vector<FlattenDecorationCase>{
+ // No OpDecorationGroup
+ {"", ""},
+
+ // OpDecorationGroup without any uses, and
+ // no OpName.
+ {"%group = OpDecorationGroup\n", ""},
+
+ // OpDecorationGroup without any uses, and
+ // with OpName targeting it. Proves you must
+ // remove the names as well.
+ {"OpName %group \"group\"\n"
+ "%group = OpDecorationGroup\n",
+ ""},
+
+ // OpDecorationGroup with decorations that
+ // target it, but no uses in OpGroupDecorate
+ // or OpGroupMemberDecorate instructions.
+ {"OpDecorate %group Flat\n"
+ "OpDecorate %group NoPerspective\n"
+ "%group = OpDecorationGroup\n",
+ ""},
+ }), );
+
+INSTANTIATE_TEST_CASE_P(OpGroupDecorate, FlattenDecorationTest,
+ ::testing::ValuesIn(std::vector<FlattenDecorationCase>{
+ // One OpGroupDecorate
+ {"OpName %group \"group\"\n"
+ "OpDecorate %group Flat\n"
+ "OpDecorate %group NoPerspective\n"
+ "%group = OpDecorationGroup\n"
+ "OpGroupDecorate %group %hue %saturation\n",
+ "OpDecorate %hue Flat\n"
+ "OpDecorate %saturation Flat\n"
+ "OpDecorate %hue NoPerspective\n"
+ "OpDecorate %saturation NoPerspective\n"},
+ // Multiple OpGroupDecorate
+ {"OpName %group \"group\"\n"
+ "OpDecorate %group Flat\n"
+ "OpDecorate %group NoPerspective\n"
+ "%group = OpDecorationGroup\n"
+ "OpGroupDecorate %group %hue %value\n"
+ "OpGroupDecorate %group %saturation\n",
+ "OpDecorate %hue Flat\n"
+ "OpDecorate %value Flat\n"
+ "OpDecorate %saturation Flat\n"
+ "OpDecorate %hue NoPerspective\n"
+ "OpDecorate %value NoPerspective\n"
+ "OpDecorate %saturation NoPerspective\n"},
+ // Two group decorations, interleaved
+ {"OpName %group0 \"group0\"\n"
+ "OpName %group1 \"group1\"\n"
+ "OpDecorate %group0 Flat\n"
+ "OpDecorate %group1 NoPerspective\n"
+ "%group0 = OpDecorationGroup\n"
+ "%group1 = OpDecorationGroup\n"
+ "OpGroupDecorate %group0 %hue %value\n"
+ "OpGroupDecorate %group1 %saturation\n",
+ "OpDecorate %hue Flat\n"
+ "OpDecorate %value Flat\n"
+ "OpDecorate %saturation NoPerspective\n"},
+ // Decoration with operands
+ {"OpName %group \"group\"\n"
+ "OpDecorate %group Location 42\n"
+ "%group = OpDecorationGroup\n"
+ "OpGroupDecorate %group %hue %saturation\n",
+ "OpDecorate %hue Location 42\n"
+ "OpDecorate %saturation Location 42\n"},
+ }), );
+
+INSTANTIATE_TEST_CASE_P(OpGroupMemberDecorate, FlattenDecorationTest,
+ ::testing::ValuesIn(std::vector<FlattenDecorationCase>{
+ // One OpGroupMemberDecorate
+ {"OpName %group \"group\"\n"
+ "OpDecorate %group Flat\n"
+ "OpDecorate %group Offset 16\n"
+ "%group = OpDecorationGroup\n"
+ "OpGroupMemberDecorate %group %Point 1\n",
+ "OpMemberDecorate %Point 1 Flat\n"
+ "OpMemberDecorate %Point 1 Offset 16\n"},
+ // Multiple OpGroupMemberDecorate using the same
+ // decoration group.
+ {"OpName %group \"group\"\n"
+ "OpDecorate %group Flat\n"
+ "OpDecorate %group NoPerspective\n"
+ "OpDecorate %group Offset 8\n"
+ "%group = OpDecorationGroup\n"
+ "OpGroupMemberDecorate %group %Point 2\n"
+ "OpGroupMemberDecorate %group %Camera 1\n",
+ "OpMemberDecorate %Point 2 Flat\n"
+ "OpMemberDecorate %Camera 1 Flat\n"
+ "OpMemberDecorate %Point 2 NoPerspective\n"
+ "OpMemberDecorate %Camera 1 NoPerspective\n"
+ "OpMemberDecorate %Point 2 Offset 8\n"
+ "OpMemberDecorate %Camera 1 Offset 8\n"},
+ // Two groups of member decorations, interleaved.
+ // Decoration is with and without operands.
+ {"OpName %group0 \"group0\"\n"
+ "OpName %group1 \"group1\"\n"
+ "OpDecorate %group0 Flat\n"
+ "OpDecorate %group0 Offset 8\n"
+ "OpDecorate %group1 NoPerspective\n"
+ "OpDecorate %group1 Offset 16\n"
+ "%group0 = OpDecorationGroup\n"
+ "%group1 = OpDecorationGroup\n"
+ "OpGroupMemberDecorate %group0 %Point 0\n"
+ "OpGroupMemberDecorate %group1 %Point 2\n",
+ "OpMemberDecorate %Point 0 Flat\n"
+ "OpMemberDecorate %Point 0 Offset 8\n"
+ "OpMemberDecorate %Point 2 NoPerspective\n"
+ "OpMemberDecorate %Point 2 Offset 16\n"},
+ }), );
+
+INSTANTIATE_TEST_CASE_P(UnrelatedDecorations, FlattenDecorationTest,
+ ::testing::ValuesIn(std::vector<FlattenDecorationCase>{
+ // A non-group non-member decoration is untouched.
+ {"OpDecorate %hue Centroid\n"
+ "OpDecorate %saturation Flat\n",
+ "OpDecorate %hue Centroid\n"
+ "OpDecorate %saturation Flat\n"},
+ // A non-group member decoration is untouched.
+ {"OpMemberDecorate %Point 0 Offset 0\n"
+ "OpMemberDecorate %Point 1 Offset 4\n"
+ "OpMemberDecorate %Point 1 Flat\n",
+ "OpMemberDecorate %Point 0 Offset 0\n"
+ "OpMemberDecorate %Point 1 Offset 4\n"
+ "OpMemberDecorate %Point 1 Flat\n"},
+ // A non-group non-member decoration survives any
+ // replacement of group decorations.
+ {"OpName %group \"group\"\n"
+ "OpDecorate %group Flat\n"
+ "OpDecorate %hue Centroid\n"
+ "OpDecorate %group NoPerspective\n"
+ "%group = OpDecorationGroup\n"
+ "OpGroupDecorate %group %hue %saturation\n",
+ "OpDecorate %hue Flat\n"
+ "OpDecorate %saturation Flat\n"
+ "OpDecorate %hue Centroid\n"
+ "OpDecorate %hue NoPerspective\n"
+ "OpDecorate %saturation NoPerspective\n"},
+ // A non-group member decoration survives any
+ // replacement of group decorations.
+ {"OpDecorate %group Offset 0\n"
+ "OpDecorate %group Flat\n"
+ "OpMemberDecorate %Point 1 Offset 4\n"
+ "%group = OpDecorationGroup\n"
+ "OpGroupMemberDecorate %group %Point 0\n",
+ "OpMemberDecorate %Point 0 Offset 0\n"
+ "OpMemberDecorate %Point 0 Flat\n"
+ "OpMemberDecorate %Point 1 Offset 4\n"},
+ }), );
+
+} // anonymous namespace
diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp
index 66eb1761..30df2556 100644
--- a/test/opt/inline_test.cpp
+++ b/test/opt/inline_test.cpp
@@ -1358,6 +1358,250 @@ TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) {
/* skip_nop = */ false, /* do_validate = */ true);
}
+TEST_F(InlineTest, EarlyReturnFunctionInlined) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // float foo(vec4 bar)
+ // {
+ // if (bar.x < 0.0)
+ // return 0.0;
+ // return bar.x;
+ // }
+ //
+ // void main()
+ // {
+ // vec4 color = vec4(foo(BaseColor));
+ // gl_FragColor = color;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %foo_vf4_ "foo(vf4;"
+OpName %bar "bar"
+OpName %color "color"
+OpName %BaseColor "BaseColor"
+OpName %param "param"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%14 = OpTypeFunction %float %_ptr_Function_v4float
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string nonEntryFuncs =
+ R"(%foo_vf4_ = OpFunction %float None %14
+%bar = OpFunctionParameter %_ptr_Function_v4float
+%27 = OpLabel
+%28 = OpAccessChain %_ptr_Function_float %bar %uint_0
+%29 = OpLoad %float %28
+%30 = OpFOrdLessThan %bool %29 %float_0
+OpSelectionMerge %31 None
+OpBranchConditional %30 %32 %31
+%32 = OpLabel
+OpReturnValue %float_0
+%31 = OpLabel
+%33 = OpAccessChain %_ptr_Function_float %bar %uint_0
+%34 = OpLoad %float %33
+OpReturnValue %34
+OpFunctionEnd
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %10
+%22 = OpLabel
+%color = OpVariable %_ptr_Function_v4float Function
+%param = OpVariable %_ptr_Function_v4float Function
+%23 = OpLoad %v4float %BaseColor
+OpStore %param %23
+%24 = OpFunctionCall %float %foo_vf4_ %param
+%25 = OpCompositeConstruct %v4float %24 %24 %24 %24
+OpStore %color %25
+%26 = OpLoad %v4float %color
+OpStore %gl_FragColor %26
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%false = OpConstantFalse %bool
+%main = OpFunction %void None %10
+%22 = OpLabel
+%35 = OpVariable %_ptr_Function_float Function
+%color = OpVariable %_ptr_Function_v4float Function
+%param = OpVariable %_ptr_Function_v4float Function
+%23 = OpLoad %v4float %BaseColor
+OpStore %param %23
+OpBranch %36
+%36 = OpLabel
+OpLoopMerge %37 %38 None
+OpBranch %39
+%39 = OpLabel
+%40 = OpAccessChain %_ptr_Function_float %param %uint_0
+%41 = OpLoad %float %40
+%42 = OpFOrdLessThan %bool %41 %float_0
+OpSelectionMerge %43 None
+OpBranchConditional %42 %44 %43
+%44 = OpLabel
+OpStore %35 %float_0
+OpBranch %37
+%43 = OpLabel
+%45 = OpAccessChain %_ptr_Function_float %param %uint_0
+%46 = OpLoad %float %45
+OpStore %35 %46
+OpBranch %37
+%38 = OpLabel
+OpBranchConditional %false %36 %37
+%37 = OpLabel
+%24 = OpLoad %float %35
+%25 = OpCompositeConstruct %v4float %24 %24 %24 %24
+OpStore %color %25
+%26 = OpLoad %v4float %color
+OpStore %gl_FragColor %26
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::InlinePass>(predefs + before + nonEntryFuncs,
+ predefs + after + nonEntryFuncs, false, true);
+}
+TEST_F(InlineTest, EarlyReturnInLoopIsNotInlined) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // float foo(vec4 bar)
+ // {
+ // while (true) {
+ // if (bar.x < 0.0)
+ // return 0.0;
+ // return bar.x;
+ // }
+ // }
+ //
+ // void main()
+ // {
+ // vec4 color = vec4(foo(BaseColor));
+ // gl_FragColor = color;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %foo_vf4_ "foo(vf4;"
+OpName %bar "bar"
+OpName %color "color"
+OpName %BaseColor "BaseColor"
+OpName %param "param"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%14 = OpTypeFunction %float %_ptr_Function_v4float
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%main = OpFunction %void None %10
+%23 = OpLabel
+%color = OpVariable %_ptr_Function_v4float Function
+%param = OpVariable %_ptr_Function_v4float Function
+%24 = OpLoad %v4float %BaseColor
+OpStore %param %24
+%25 = OpFunctionCall %float %foo_vf4_ %param
+%26 = OpCompositeConstruct %v4float %25 %25 %25 %25
+OpStore %color %26
+%27 = OpLoad %v4float %color
+OpStore %gl_FragColor %27
+OpReturn
+OpFunctionEnd
+%foo_vf4_ = OpFunction %float None %14
+%bar = OpFunctionParameter %_ptr_Function_v4float
+%28 = OpLabel
+OpBranch %29
+%29 = OpLabel
+OpLoopMerge %30 %31 None
+OpBranch %32
+%32 = OpLabel
+OpBranchConditional %true %33 %30
+%33 = OpLabel
+%34 = OpAccessChain %_ptr_Function_float %bar %uint_0
+%35 = OpLoad %float %34
+%36 = OpFOrdLessThan %bool %35 %float_0
+OpSelectionMerge %37 None
+OpBranchConditional %36 %38 %37
+%38 = OpLabel
+OpReturnValue %float_0
+%37 = OpLabel
+%39 = OpAccessChain %_ptr_Function_float %bar %uint_0
+%40 = OpLoad %float %39
+OpReturnValue %40
+%31 = OpLabel
+OpBranch %29
+%30 = OpLabel
+%41 = OpUndef %float
+OpReturnValue %41
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::InlinePass>(assembly, assembly, false, true);
+}
+
+TEST_F(InlineTest, ExternalFunctionIsNotInlined) {
+ // In particular, don't crash.
+ // See report https://github.com/KhronosGroup/SPIRV-Tools/issues/605
+ const std::string assembly =
+ R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %1 "entry_pt"
+OpDecorate %2 LinkageAttributes "external" Import
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%2 = OpFunction %void None %4
+OpFunctionEnd
+%1 = OpFunction %void None %4
+%5 = OpLabel
+%6 = OpFunctionCall %void %2
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::InlinePass>(assembly, assembly, false, true);
+}
+
// TODO(greg-lunarg): Add tests to verify handling of these cases:
//
// Empty modules
diff --git a/test/opt/local_access_chain_convert_test.cpp b/test/opt/local_access_chain_convert_test.cpp
new file mode 100644
index 00000000..ad37622d
--- /dev/null
+++ b/test/opt/local_access_chain_convert_test.cpp
@@ -0,0 +1,422 @@
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using LocalAccessChainConvertTest = PassTest<::testing::Test>;
+
+TEST_F(LocalAccessChainConvertTest, StructOfVecsOfFloatConverted) {
+
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // struct S_t {
+ // vec4 v0;
+ // vec4 v1;
+ // };
+ //
+ // void main()
+ // {
+ // S_t s0;
+ // s0.v1 = BaseColor;
+ // gl_FragColor = s0.v1;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %S_t "S_t"
+OpMemberName %S_t 0 "v0"
+OpMemberName %S_t 1 "v1"
+OpName %s0 "s0"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%S_t = OpTypeStruct %v4float %v4float
+%_ptr_Function_S_t = OpTypePointer Function %S_t
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%18 = OpLoad %v4float %BaseColor
+%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1
+OpStore %19 %18
+%20 = OpAccessChain %_ptr_Function_v4float %s0 %int_1
+%21 = OpLoad %v4float %20
+OpStore %gl_FragColor %21
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%18 = OpLoad %v4float %BaseColor
+%22 = OpLoad %S_t %s0
+%23 = OpCompositeInsert %S_t %18 %22 1
+OpStore %s0 %23
+%24 = OpLoad %S_t %s0
+%25 = OpCompositeExtract %v4float %24 1
+OpStore %gl_FragColor %25
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalAccessChainConvertPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(LocalAccessChainConvertTest, InBoundsAccessChainsConverted) {
+
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // struct S_t {
+ // vec4 v0;
+ // vec4 v1;
+ // };
+ //
+ // void main()
+ // {
+ // S_t s0;
+ // s0.v1 = BaseColor;
+ // gl_FragColor = s0.v1;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %S_t "S_t"
+OpMemberName %S_t 0 "v0"
+OpMemberName %S_t 1 "v1"
+OpName %s0 "s0"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%S_t = OpTypeStruct %v4float %v4float
+%_ptr_Function_S_t = OpTypePointer Function %S_t
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%18 = OpLoad %v4float %BaseColor
+%19 = OpInBoundsAccessChain %_ptr_Function_v4float %s0 %int_1
+OpStore %19 %18
+%20 = OpInBoundsAccessChain %_ptr_Function_v4float %s0 %int_1
+%21 = OpLoad %v4float %20
+OpStore %gl_FragColor %21
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%18 = OpLoad %v4float %BaseColor
+%22 = OpLoad %S_t %s0
+%23 = OpCompositeInsert %S_t %18 %22 1
+OpStore %s0 %23
+%24 = OpLoad %S_t %s0
+%25 = OpCompositeExtract %v4float %24 1
+OpStore %gl_FragColor %25
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalAccessChainConvertPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(LocalAccessChainConvertTest, TwoUsesofSingleChainConverted) {
+
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // struct S_t {
+ // vec4 v0;
+ // vec4 v1;
+ // };
+ //
+ // void main()
+ // {
+ // S_t s0;
+ // s0.v1 = BaseColor;
+ // gl_FragColor = s0.v1;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %S_t "S_t"
+OpMemberName %S_t 0 "v0"
+OpMemberName %S_t 1 "v1"
+OpName %s0 "s0"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%S_t = OpTypeStruct %v4float %v4float
+%_ptr_Function_S_t = OpTypePointer Function %S_t
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%18 = OpLoad %v4float %BaseColor
+%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1
+OpStore %19 %18
+%20 = OpLoad %v4float %19
+OpStore %gl_FragColor %20
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%18 = OpLoad %v4float %BaseColor
+%21 = OpLoad %S_t %s0
+%22 = OpCompositeInsert %S_t %18 %21 1
+OpStore %s0 %22
+%23 = OpLoad %S_t %s0
+%24 = OpCompositeExtract %v4float %23 1
+OpStore %gl_FragColor %24
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalAccessChainConvertPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(LocalAccessChainConvertTest,
+ UntargetedTypeNotConverted) {
+
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // struct S1_t {
+ // vec4 v1;
+ // };
+ //
+ // struct S2_t {
+ // vec4 v2;
+ // S1_t s1;
+ // };
+ //
+ // void main()
+ // {
+ // S2_t s2;
+ // s2.s1.v1 = BaseColor;
+ // gl_FragColor = s2.s1.v1;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %S1_t "S1_t"
+OpMemberName %S1_t 0 "v1"
+OpName %S2_t "S2_t"
+OpMemberName %S2_t 0 "v2"
+OpMemberName %S2_t 1 "s1"
+OpName %s2 "s2"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%S1_t = OpTypeStruct %v4float
+%S2_t = OpTypeStruct %v4float %S1_t
+%_ptr_Function_S2_t = OpTypePointer Function %S2_t
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%int_0 = OpConstant %int 0
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%main = OpFunction %void None %9
+%19 = OpLabel
+%s2 = OpVariable %_ptr_Function_S2_t Function
+%20 = OpLoad %v4float %BaseColor
+%21 = OpAccessChain %_ptr_Function_v4float %s2 %int_1 %int_0
+OpStore %21 %20
+%22 = OpAccessChain %_ptr_Function_v4float %s2 %int_1 %int_0
+%23 = OpLoad %v4float %22
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalAccessChainConvertPass>(
+ assembly, assembly, false, true);
+}
+
+TEST_F(LocalAccessChainConvertTest,
+ DynamicallyIndexedVarNotConverted) {
+
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // flat in int Idx;
+ // in float Bi;
+ //
+ // struct S_t {
+ // vec4 v0;
+ // vec4 v1;
+ // };
+ //
+ // void main()
+ // {
+ // S_t s0;
+ // s0.v1 = BaseColor;
+ // s0.v1[Idx] = Bi;
+ // gl_FragColor = s0.v1;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Idx %Bi %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %S_t "S_t"
+OpMemberName %S_t 0 "v0"
+OpMemberName %S_t 1 "v1"
+OpName %s0 "s0"
+OpName %BaseColor "BaseColor"
+OpName %Idx "Idx"
+OpName %Bi "Bi"
+OpName %gl_FragColor "gl_FragColor"
+OpDecorate %Idx Flat
+%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%S_t = OpTypeStruct %v4float %v4float
+%_ptr_Function_S_t = OpTypePointer Function %S_t
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_int = OpTypePointer Input %int
+%Idx = OpVariable %_ptr_Input_int Input
+%_ptr_Input_float = OpTypePointer Input %float
+%Bi = OpVariable %_ptr_Input_float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%main = OpFunction %void None %10
+%22 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%23 = OpLoad %v4float %BaseColor
+%24 = OpAccessChain %_ptr_Function_v4float %s0 %int_1
+OpStore %24 %23
+%25 = OpLoad %int %Idx
+%26 = OpLoad %float %Bi
+%27 = OpAccessChain %_ptr_Function_float %s0 %int_1 %25
+OpStore %27 %26
+%28 = OpAccessChain %_ptr_Function_v4float %s0 %int_1
+%29 = OpLoad %v4float %28
+OpStore %gl_FragColor %29
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalAccessChainConvertPass>(
+ assembly, assembly, false, true);
+}
+
+// TODO(greg-lunarg): Add tests to verify handling of these cases:
+//
+// Assorted vector and matrix types
+// Assorted struct array types
+// Assorted scalar types
+// Assorted non-target types
+// OpInBoundsAccessChain
+// Others?
+
+} // anonymous namespace
diff --git a/test/opt/local_single_block_elim.cpp b/test/opt/local_single_block_elim.cpp
new file mode 100644
index 00000000..8c193bda
--- /dev/null
+++ b/test/opt/local_single_block_elim.cpp
@@ -0,0 +1,469 @@
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+template <typename T> std::vector<T> concat(const std::vector<T> &a, const std::vector<T> &b) {
+ std::vector<T> ret = std::vector<T>();
+ std::copy(a.begin(), a.end(), back_inserter(ret));
+ std::copy(b.begin(), b.end(), back_inserter(ret));
+ return ret;
+}
+
+namespace {
+
+using namespace spvtools;
+
+using LocalSingleBlockLoadStoreElimTest = PassTest<::testing::Test>;
+
+TEST_F(LocalSingleBlockLoadStoreElimTest, SimpleStoreLoadElim) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %7
+%13 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%14 = OpLoad %v4float %BaseColor
+OpStore %v %14
+%15 = OpLoad %v4float %v
+OpStore %gl_FragColor %15
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %7
+%13 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%14 = OpLoad %v4float %BaseColor
+OpStore %gl_FragColor %14
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(LocalSingleBlockLoadStoreElimTest, SimpleLoadLoadElim) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in float fi;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // if (fi < 0)
+ // v = vec4(0.0);
+ // gl_FragData[0] = v;
+ // gl_FragData[1] = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragData
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %fi "fi"
+OpName %gl_FragData "gl_FragData"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Input_float = OpTypePointer Input %float
+%fi = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%16 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%uint = OpTypeInt 32 0
+%uint_32 = OpConstant %uint 32
+%_arr_v4float_uint_32 = OpTypeArray %v4float %uint_32
+%_ptr_Output__arr_v4float_uint_32 = OpTypePointer Output %_arr_v4float_uint_32
+%gl_FragData = OpVariable %_ptr_Output__arr_v4float_uint_32 Output
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%int_1 = OpConstant %int 1
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%25 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%26 = OpLoad %v4float %BaseColor
+OpStore %v %26
+%27 = OpLoad %float %fi
+%28 = OpFOrdLessThan %bool %27 %float_0
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+OpStore %v %16
+OpBranch %29
+%29 = OpLabel
+%31 = OpLoad %v4float %v
+%32 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0
+OpStore %32 %31
+%33 = OpLoad %v4float %v
+%34 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1
+OpStore %34 %33
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%25 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%26 = OpLoad %v4float %BaseColor
+OpStore %v %26
+%27 = OpLoad %float %fi
+%28 = OpFOrdLessThan %bool %27 %float_0
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+OpStore %v %16
+OpBranch %29
+%29 = OpLabel
+%31 = OpLoad %v4float %v
+%32 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0
+OpStore %32 %31
+%34 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1
+OpStore %34 %31
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(LocalSingleBlockLoadStoreElimTest,
+ NoStoreElimIfInterveningAccessChainLoad) {
+ //
+ // Note that even though the Load to %v is eliminated, the Store to %v
+ // is not eliminated due to the following access chain reference.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // flat in int Idx;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // float f = v[Idx];
+ // gl_FragColor = v/f;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Idx %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %Idx "Idx"
+OpName %gl_FragColor "gl_FragColor"
+OpDecorate %Idx Flat
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%int = OpTypeInt 32 1
+%_ptr_Input_int = OpTypePointer Input %int
+%Idx = OpVariable %_ptr_Input_int Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %9
+%18 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%f = OpVariable %_ptr_Function_float Function
+%19 = OpLoad %v4float %BaseColor
+OpStore %v %19
+%20 = OpLoad %int %Idx
+%21 = OpAccessChain %_ptr_Function_float %v %20
+%22 = OpLoad %float %21
+OpStore %f %22
+%23 = OpLoad %v4float %v
+%24 = OpLoad %float %f
+%25 = OpCompositeConstruct %v4float %24 %24 %24 %24
+%26 = OpFDiv %v4float %23 %25
+OpStore %gl_FragColor %26
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %9
+%18 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%f = OpVariable %_ptr_Function_float Function
+%19 = OpLoad %v4float %BaseColor
+OpStore %v %19
+%20 = OpLoad %int %Idx
+%21 = OpAccessChain %_ptr_Function_float %v %20
+%22 = OpLoad %float %21
+%25 = OpCompositeConstruct %v4float %22 %22 %22 %22
+%26 = OpFDiv %v4float %19 %25
+OpStore %gl_FragColor %26
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfInterveningAccessChainStore) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // flat in int Idx;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // v[Idx] = 0;
+ // gl_FragColor = v;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Idx %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %Idx "Idx"
+OpName %gl_FragColor "gl_FragColor"
+OpDecorate %Idx Flat
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%int = OpTypeInt 32 1
+%_ptr_Input_int = OpTypePointer Input %int
+%Idx = OpVariable %_ptr_Input_int Input
+%float_0 = OpConstant %float 0
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%main = OpFunction %void None %8
+%18 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%19 = OpLoad %v4float %BaseColor
+OpStore %v %19
+%20 = OpLoad %int %Idx
+%21 = OpAccessChain %_ptr_Function_float %v %20
+OpStore %21 %float_0
+%22 = OpLoad %v4float %v
+OpStore %gl_FragColor %22
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
+ assembly, assembly, false, true);
+}
+
+TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfInterveningFunctionCall) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void foo() {
+ // }
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // foo();
+ // gl_FragColor = v;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %foo_ "foo("
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%main = OpFunction %void None %8
+%14 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%15 = OpLoad %v4float %BaseColor
+OpStore %v %15
+%16 = OpFunctionCall %void %foo_
+%17 = OpLoad %v4float %v
+OpStore %gl_FragColor %17
+OpReturn
+OpFunctionEnd
+%foo_ = OpFunction %void None %8
+%18 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
+ assembly, assembly, false, true);
+}
+
+TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfCopyObjectInFunction) {
+ // Note: SPIR-V hand edited to insert CopyObject
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v1 = BaseColor;
+ // gl_FragData[0] = v1;
+ // vec4 v2 = BaseColor * 0.5;
+ // gl_FragData[1] = v2;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragData
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v1 "v1"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragData "gl_FragData"
+OpName %v2 "v2"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%uint = OpTypeInt 32 0
+%uint_32 = OpConstant %uint 32
+%_arr_v4float_uint_32 = OpTypeArray %v4float %uint_32
+%_ptr_Output__arr_v4float_uint_32 = OpTypePointer Output %_arr_v4float_uint_32
+%gl_FragData = OpVariable %_ptr_Output__arr_v4float_uint_32 Output
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%float_0_5 = OpConstant %float 0.5
+%int_1 = OpConstant %int 1
+%main = OpFunction %void None %8
+%22 = OpLabel
+%v1 = OpVariable %_ptr_Function_v4float Function
+%v2 = OpVariable %_ptr_Function_v4float Function
+%23 = OpLoad %v4float %BaseColor
+OpStore %v1 %23
+%24 = OpLoad %v4float %v1
+%25 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0
+OpStore %25 %24
+%26 = OpLoad %v4float %BaseColor
+%27 = OpVectorTimesScalar %v4float %26 %float_0_5
+%28 = OpCopyObject %_ptr_Function_v4float %v2
+OpStore %28 %27
+%29 = OpLoad %v4float %28
+%30 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1
+OpStore %30 %29
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
+ assembly, assembly, false, true);
+}
+
+// TODO(greg-lunarg): Add tests to verify handling of these cases:
+//
+// Other target variable types
+// InBounds Access Chains
+// Check for correctness in the presence of function calls
+// Others?
+
+} // anonymous namespace
diff --git a/test/opt/optimizer_test.cpp b/test/opt/optimizer_test.cpp
new file mode 100644
index 00000000..ef0e9921
--- /dev/null
+++ b/test/opt/optimizer_test.cpp
@@ -0,0 +1,109 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+
+#include "spirv-tools/libspirv.hpp"
+#include "spirv-tools/optimizer.hpp"
+
+#include "pass_fixture.h"
+
+namespace {
+
+using spvtools::CreateNullPass;
+using spvtools::CreateStripDebugInfoPass;
+using spvtools::Optimizer;
+using spvtools::SpirvTools;
+using ::testing::Eq;
+
+TEST(Optimizer, CanRunNullPassWithDistinctInputOutputVectors) {
+ SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
+ std::vector<uint32_t> binary_in;
+ tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary_in);
+
+ Optimizer opt(SPV_ENV_UNIVERSAL_1_0);
+ opt.RegisterPass(CreateNullPass());
+ std::vector<uint32_t> binary_out;
+ opt.Run(binary_in.data(), binary_in.size(), &binary_out);
+
+ std::string disassembly;
+ tools.Disassemble(binary_out.data(), binary_out.size(), &disassembly);
+ EXPECT_THAT(disassembly, Eq("OpName %foo \"foo\"\n%foo = OpTypeVoid\n"));
+}
+
+TEST(Optimizer, CanRunTransformingPassWithDistinctInputOutputVectors) {
+ SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
+ std::vector<uint32_t> binary_in;
+ tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary_in);
+
+ Optimizer opt(SPV_ENV_UNIVERSAL_1_0);
+ opt.RegisterPass(CreateStripDebugInfoPass());
+ std::vector<uint32_t> binary_out;
+ opt.Run(binary_in.data(), binary_in.size(), &binary_out);
+
+ std::string disassembly;
+ tools.Disassemble(binary_out.data(), binary_out.size(), &disassembly);
+ EXPECT_THAT(disassembly, Eq("%void = OpTypeVoid\n"));
+}
+
+TEST(Optimizer, CanRunNullPassWithAliasedVectors) {
+ SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
+ std::vector<uint32_t> binary;
+ tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary);
+
+ Optimizer opt(SPV_ENV_UNIVERSAL_1_0);
+ opt.RegisterPass(CreateNullPass());
+ opt.Run(binary.data(), binary.size(), &binary); // This is the key.
+
+ std::string disassembly;
+ tools.Disassemble(binary.data(), binary.size(), &disassembly);
+ EXPECT_THAT(disassembly, Eq("OpName %foo \"foo\"\n%foo = OpTypeVoid\n"));
+}
+
+TEST(Optimizer, CanRunNullPassWithAliasedVectorDataButDifferentSize) {
+ SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
+ std::vector<uint32_t> binary;
+ tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary);
+
+ Optimizer opt(SPV_ENV_UNIVERSAL_1_0);
+ opt.RegisterPass(CreateNullPass());
+ auto orig_size = binary.size();
+ // Now change the size. Add a word that will be ignored
+ // by the optimizer.
+ binary.push_back(42);
+ EXPECT_THAT(orig_size + 1, Eq(binary.size()));
+ opt.Run(binary.data(), orig_size, &binary); // This is the key.
+ // The binary vector should have been rewritten.
+ EXPECT_THAT(binary.size(), Eq(orig_size));
+
+ std::string disassembly;
+ tools.Disassemble(binary.data(), binary.size(), &disassembly);
+ EXPECT_THAT(disassembly, Eq("OpName %foo \"foo\"\n%foo = OpTypeVoid\n"));
+}
+
+TEST(Optimizer, CanRunTransformingPassWithAliasedVectors) {
+ SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
+ std::vector<uint32_t> binary;
+ tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary);
+
+ Optimizer opt(SPV_ENV_UNIVERSAL_1_0);
+ opt.RegisterPass(CreateStripDebugInfoPass());
+ opt.Run(binary.data(), binary.size(), &binary); // This is the key
+
+ std::string disassembly;
+ tools.Disassemble(binary.data(), binary.size(), &disassembly);
+ EXPECT_THAT(disassembly, Eq("%void = OpTypeVoid\n"));
+}
+
+} // namespace
diff --git a/test/opt/pass_fixture.h b/test/opt/pass_fixture.h
index 1b257a66..c5c45d95 100644
--- a/test/opt/pass_fixture.h
+++ b/test/opt/pass_fixture.h
@@ -43,15 +43,17 @@ class PassTest : public TestT {
PassTest()
: consumer_(nullptr),
tools_(SPV_ENV_UNIVERSAL_1_1),
- manager_(new opt::PassManager()) {}
+ manager_(new opt::PassManager()),
+ assemble_options_(SpirvTools::kDefaultAssembleOption),
+ disassemble_options_(SpirvTools::kDefaultDisassembleOption) {}
// Runs the given |pass| on the binary assembled from the |original|.
// Returns a tuple of the optimized binary and the boolean value returned
// from pass Process() function.
std::tuple<std::vector<uint32_t>, opt::Pass::Status> OptimizeToBinary(
opt::Pass* pass, const std::string& original, bool skip_nop) {
- std::unique_ptr<ir::Module> module =
- BuildModule(SPV_ENV_UNIVERSAL_1_1, consumer_, original);
+ std::unique_ptr<ir::Module> module = BuildModule(
+ SPV_ENV_UNIVERSAL_1_1, consumer_, original, assemble_options_);
EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
<< original << std::endl;
if (!module) {
@@ -88,7 +90,8 @@ class PassTest : public TestT {
std::tie(optimized_bin, status) = SinglePassRunToBinary<PassT>(
assembly, skip_nop, std::forward<Args>(args)...);
std::string optimized_asm;
- EXPECT_TRUE(tools_.Disassemble(optimized_bin, &optimized_asm))
+ EXPECT_TRUE(tools_.Disassemble(optimized_bin, &optimized_asm,
+ disassemble_options_))
<< "Disassembling failed for shader:\n"
<< assembly << std::endl;
return std::make_tuple(optimized_asm, status);
@@ -125,7 +128,8 @@ class PassTest : public TestT {
spvContextDestroy(context);
}
std::string optimized_asm;
- EXPECT_TRUE(tools_.Disassemble(optimized_bin, &optimized_asm))
+ EXPECT_TRUE(tools_.Disassemble(optimized_bin, &optimized_asm,
+ disassemble_options_))
<< "Disassembling failed for shader:\n"
<< original << std::endl;
EXPECT_EQ(expected, optimized_asm);
@@ -162,8 +166,8 @@ class PassTest : public TestT {
void RunAndCheck(const std::string& original, const std::string& expected) {
assert(manager_->NumPasses());
- std::unique_ptr<ir::Module> module =
- BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original);
+ std::unique_ptr<ir::Module> module = BuildModule(
+ SPV_ENV_UNIVERSAL_1_1, nullptr, original, assemble_options_);
ASSERT_NE(nullptr, module);
manager_->Run(module.get());
@@ -172,14 +176,25 @@ class PassTest : public TestT {
module->ToBinary(&binary, /* skip_nop = */ false);
std::string optimized;
- EXPECT_TRUE(tools_.Disassemble(binary, &optimized));
+ EXPECT_TRUE(tools_.Disassemble(binary, &optimized,
+ disassemble_options_));
EXPECT_EQ(expected, optimized);
}
+ void SetAssembleOptions(uint32_t assemble_options) {
+ assemble_options_ = assemble_options;
+ }
+
+ void SetDisassembleOptions(uint32_t disassemble_options) {
+ disassemble_options_ = disassemble_options;
+ }
+
private:
MessageConsumer consumer_; // Message consumer.
SpirvTools tools_; // An instance for calling SPIRV-Tools functionalities.
std::unique_ptr<opt::PassManager> manager_; // The pass manager.
+ uint32_t assemble_options_;
+ uint32_t disassemble_options_;
};
} // namespace spvtools
diff --git a/test/opt/set_spec_const_default_value_test.cpp b/test/opt/set_spec_const_default_value_test.cpp
index a6eaab9e..e4edbc3c 100644
--- a/test/opt/set_spec_const_default_value_test.cpp
+++ b/test/opt/set_spec_const_default_value_test.cpp
@@ -23,6 +23,8 @@ using testing::Eq;
using SpecIdToValueStrMap =
opt::SetSpecConstantDefaultValuePass::SpecIdToValueStrMap;
+using SpecIdToValueBitPatternMap =
+ opt::SetSpecConstantDefaultValuePass::SpecIdToValueBitPatternMap;
struct DefaultValuesStringParsingTestCase {
const char* default_values_str;
@@ -40,7 +42,7 @@ TEST_P(DefaultValuesStringParsingTest, TestCase) {
tc.default_values_str);
if (tc.expect_success) {
EXPECT_NE(nullptr, actual_map);
- if (actual_map) EXPECT_THAT(*actual_map, Eq(tc.expected_map));
+ if (actual_map) { EXPECT_THAT(*actual_map, Eq(tc.expected_map)); }
} else {
EXPECT_EQ(nullptr, actual_map);
}
@@ -129,24 +131,25 @@ INSTANTIATE_TEST_CASE_P(
{"0x3p10:200", false, SpecIdToValueStrMap{}},
}));
-struct SetSpecConstantDefaultValueTestCase {
+struct SetSpecConstantDefaultValueInStringFormTestCase {
const char* code;
SpecIdToValueStrMap default_values;
const char* expected;
};
-using SetSpecConstantDefaultValueParamTest =
- PassTest<::testing::TestWithParam<SetSpecConstantDefaultValueTestCase>>;
+using SetSpecConstantDefaultValueInStringFormParamTest = PassTest<
+ ::testing::TestWithParam<SetSpecConstantDefaultValueInStringFormTestCase>>;
-TEST_P(SetSpecConstantDefaultValueParamTest, TestCase) {
+TEST_P(SetSpecConstantDefaultValueInStringFormParamTest, TestCase) {
const auto& tc = GetParam();
SinglePassRunAndCheck<opt::SetSpecConstantDefaultValuePass>(
tc.code, tc.expected, /* skip_nop = */ false, tc.default_values);
}
INSTANTIATE_TEST_CASE_P(
- ValidCases, SetSpecConstantDefaultValueParamTest,
- ::testing::ValuesIn(std::vector<SetSpecConstantDefaultValueTestCase>{
+ ValidCases, SetSpecConstantDefaultValueInStringFormParamTest,
+ ::testing::ValuesIn(std::vector<
+ SetSpecConstantDefaultValueInStringFormTestCase>{
// 0. Empty.
{"", SpecIdToValueStrMap{}, ""},
// 1. Empty with non-empty values to set.
@@ -439,8 +442,9 @@ INSTANTIATE_TEST_CASE_P(
}));
INSTANTIATE_TEST_CASE_P(
- InvalidCases, SetSpecConstantDefaultValueParamTest,
- ::testing::ValuesIn(std::vector<SetSpecConstantDefaultValueTestCase>{
+ InvalidCases, SetSpecConstantDefaultValueInStringFormParamTest,
+ ::testing::ValuesIn(std::vector<
+ SetSpecConstantDefaultValueInStringFormTestCase>{
// 0. Do not crash when decoration group is not used.
{
// code
@@ -468,7 +472,7 @@ INSTANTIATE_TEST_CASE_P(
"%int = OpTypeInt 32 1\n",
},
// 2. Do nothing when SpecId decoration is not attached to a
- // non-spec-contant instruction.
+ // non-spec-constant instruction.
{
// code
"OpDecorate %1 SpecId 100\n"
@@ -537,6 +541,519 @@ INSTANTIATE_TEST_CASE_P(
"%int = OpTypeInt 32 1\n"
"%int_100 = OpConstant %int 100\n",
},
+ // 6. Boolean type spec constant cannot be set with numeric values in
+ // string form. i.e. only 'true' and 'false' are acceptable for setting
+ // boolean type spec constants. Nothing should be done if numeric values
+ // in string form are provided.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "OpDecorate %3 SpecId 102\n"
+ "OpDecorate %4 SpecId 103\n"
+ "OpDecorate %5 SpecId 104\n"
+ "OpDecorate %6 SpecId 105\n"
+ "%bool = OpTypeBool\n"
+ "%1 = OpSpecConstantTrue %bool\n"
+ "%2 = OpSpecConstantFalse %bool\n"
+ "%3 = OpSpecConstantTrue %bool\n"
+ "%4 = OpSpecConstantTrue %bool\n"
+ "%5 = OpSpecConstantTrue %bool\n"
+ "%6 = OpSpecConstantFalse %bool\n",
+ // default values
+ SpecIdToValueStrMap{{100, "0"},
+ {101, "1"},
+ {102, "0x0"},
+ {103, "0.0"},
+ {104, "-0.0"},
+ {105, "0x12345678"}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "OpDecorate %3 SpecId 102\n"
+ "OpDecorate %4 SpecId 103\n"
+ "OpDecorate %5 SpecId 104\n"
+ "OpDecorate %6 SpecId 105\n"
+ "%bool = OpTypeBool\n"
+ "%1 = OpSpecConstantTrue %bool\n"
+ "%2 = OpSpecConstantFalse %bool\n"
+ "%3 = OpSpecConstantTrue %bool\n"
+ "%4 = OpSpecConstantTrue %bool\n"
+ "%5 = OpSpecConstantTrue %bool\n"
+ "%6 = OpSpecConstantFalse %bool\n",
+ },
+ }));
+
+struct SetSpecConstantDefaultValueInBitPatternFormTestCase {
+ const char* code;
+ SpecIdToValueBitPatternMap default_values;
+ const char* expected;
+};
+
+using SetSpecConstantDefaultValueInBitPatternFormParamTest =
+ PassTest<::testing::TestWithParam<
+ SetSpecConstantDefaultValueInBitPatternFormTestCase>>;
+
+TEST_P(SetSpecConstantDefaultValueInBitPatternFormParamTest, TestCase) {
+ const auto& tc = GetParam();
+ SinglePassRunAndCheck<opt::SetSpecConstantDefaultValuePass>(
+ tc.code, tc.expected, /* skip_nop = */ false, tc.default_values);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ ValidCases, SetSpecConstantDefaultValueInBitPatternFormParamTest,
+ ::testing::ValuesIn(std::vector<
+ SetSpecConstantDefaultValueInBitPatternFormTestCase>{
+ // 0. Empty.
+ {"", SpecIdToValueBitPatternMap{}, ""},
+ // 1. Empty with non-empty values to set.
+ {"", SpecIdToValueBitPatternMap{{1, {100}}, {2, {200}}}, ""},
+ // 2. Baisc bool type.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "%bool = OpTypeBool\n"
+ "%1 = OpSpecConstantTrue %bool\n"
+ "%2 = OpSpecConstantFalse %bool\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0x0}}, {101, {0x1}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "%bool = OpTypeBool\n"
+ "%1 = OpSpecConstantFalse %bool\n"
+ "%2 = OpSpecConstantTrue %bool\n",
+ },
+ // 3. 32-bit int type.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "OpDecorate %3 SpecId 102\n"
+ "%int = OpTypeInt 32 1\n"
+ "%1 = OpSpecConstant %int 10\n"
+ "%2 = OpSpecConstant %int 11\n"
+ "%3 = OpSpecConstant %int 11\n",
+ // default values
+ SpecIdToValueBitPatternMap{
+ {100, {2147483647}}, {101, {0xffffffff}}, {102, {0xffffffd6}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "OpDecorate %3 SpecId 102\n"
+ "%int = OpTypeInt 32 1\n"
+ "%1 = OpSpecConstant %int 2147483647\n"
+ "%2 = OpSpecConstant %int -1\n"
+ "%3 = OpSpecConstant %int -42\n",
+ },
+ // 4. 64-bit uint type.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "%ulong = OpTypeInt 64 0\n"
+ "%1 = OpSpecConstant %ulong 10\n"
+ "%2 = OpSpecConstant %ulong 11\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0xFFFFFFFE, 0xFFFFFFFF}},
+ {101, {0x100, 0x0}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "%ulong = OpTypeInt 64 0\n"
+ "%1 = OpSpecConstant %ulong 18446744073709551614\n"
+ "%2 = OpSpecConstant %ulong 256\n",
+ },
+ // 5. 32-bit float type.
+ {
+ // code
+ "OpDecorate %1 SpecId 101\n"
+ "OpDecorate %2 SpecId 102\n"
+ "%float = OpTypeFloat 32\n"
+ "%1 = OpSpecConstant %float 200\n"
+ "%2 = OpSpecConstant %float 201\n",
+ // default values
+ SpecIdToValueBitPatternMap{{101, {0xffffffff}},
+ {102, {0x40200000}}},
+ // expected
+ "OpDecorate %1 SpecId 101\n"
+ "OpDecorate %2 SpecId 102\n"
+ "%float = OpTypeFloat 32\n"
+ "%1 = OpSpecConstant %float -0x1.fffffep+128\n"
+ "%2 = OpSpecConstant %float 2.5\n",
+ },
+ // 6. 64-bit float type.
+ {
+ // code
+ "OpDecorate %1 SpecId 201\n"
+ "OpDecorate %2 SpecId 202\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %double 3.14159265358979\n"
+ "%2 = OpSpecConstant %double 0.142857\n",
+ // default values
+ SpecIdToValueBitPatternMap{{201, {0xffffffff, 0x7fffffff}},
+ {202, {0x00000000, 0xc0404000}}},
+ // expected
+ "OpDecorate %1 SpecId 201\n"
+ "OpDecorate %2 SpecId 202\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %double 0x1.fffffffffffffp+1024\n"
+ "%2 = OpSpecConstant %double -32.5\n",
+ },
+ // 7. SpecId not found, expect no modification.
+ {
+ // code
+ "OpDecorate %1 SpecId 201\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %double 3.14159265358979\n",
+ // default values
+ SpecIdToValueBitPatternMap{{8888, {0x0}}},
+ // expected
+ "OpDecorate %1 SpecId 201\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %double 3.14159265358979\n",
+ },
+ // 8. Multiple types of spec constants.
+ {
+ // code
+ "OpDecorate %1 SpecId 201\n"
+ "OpDecorate %2 SpecId 202\n"
+ "OpDecorate %3 SpecId 203\n"
+ "%bool = OpTypeBool\n"
+ "%int = OpTypeInt 32 1\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %double 3.14159265358979\n"
+ "%2 = OpSpecConstant %int 1024\n"
+ "%3 = OpSpecConstantTrue %bool\n",
+ // default values
+ SpecIdToValueBitPatternMap{
+ {201, {0xffffffff, 0x7fffffff}},
+ {202, {0x00000800}},
+ {203, {0x0}},
+ },
+ // expected
+ "OpDecorate %1 SpecId 201\n"
+ "OpDecorate %2 SpecId 202\n"
+ "OpDecorate %3 SpecId 203\n"
+ "%bool = OpTypeBool\n"
+ "%int = OpTypeInt 32 1\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %double 0x1.fffffffffffffp+1024\n"
+ "%2 = OpSpecConstant %int 2048\n"
+ "%3 = OpSpecConstantFalse %bool\n",
+ },
+ // 9. Ignore other decorations.
+ {
+ // code
+ "OpDecorate %1 ArrayStride 4\n"
+ "%int = OpTypeInt 32 1\n"
+ "%1 = OpSpecConstant %int 100\n",
+ // default values
+ SpecIdToValueBitPatternMap{{4, {0x7fffffff}}},
+ // expected
+ "OpDecorate %1 ArrayStride 4\n"
+ "%int = OpTypeInt 32 1\n"
+ "%1 = OpSpecConstant %int 100\n",
+ },
+ // 10. Distinguish from other decorations.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %1 ArrayStride 4\n"
+ "%int = OpTypeInt 32 1\n"
+ "%1 = OpSpecConstant %int 100\n",
+ // default values
+ SpecIdToValueBitPatternMap{{4, {0x7fffffff}}, {100, {0xffffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %1 ArrayStride 4\n"
+ "%int = OpTypeInt 32 1\n"
+ "%1 = OpSpecConstant %int -1\n",
+ },
+ // 11. Decorate through decoration group.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int 100\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0x7fffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int 2147483647\n",
+ },
+ // 12. Ignore other decorations in decoration group.
+ {
+ // code
+ "OpDecorate %1 ArrayStride 4\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int 100\n",
+ // default values
+ SpecIdToValueBitPatternMap{{4, {0x7fffffff}}},
+ // expected
+ "OpDecorate %1 ArrayStride 4\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int 100\n",
+ },
+ // 13. Distinguish from other decorations in decoration group.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %1 ArrayStride 4\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int 100\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0x7fffffff}}, {4, {0x00000001}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %1 ArrayStride 4\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int 2147483647\n",
+ },
+ // 14. Unchanged bool default value
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "%bool = OpTypeBool\n"
+ "%1 = OpSpecConstantTrue %bool\n"
+ "%2 = OpSpecConstantFalse %bool\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0x1}}, {101, {0x0}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "%bool = OpTypeBool\n"
+ "%1 = OpSpecConstantTrue %bool\n"
+ "%2 = OpSpecConstantFalse %bool\n",
+ },
+ // 15. Unchanged int default values
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "%int = OpTypeInt 32 1\n"
+ "%ulong = OpTypeInt 64 0\n"
+ "%1 = OpSpecConstant %int 10\n"
+ "%2 = OpSpecConstant %ulong 11\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {10}}, {101, {11, 0}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "%int = OpTypeInt 32 1\n"
+ "%ulong = OpTypeInt 64 0\n"
+ "%1 = OpSpecConstant %int 10\n"
+ "%2 = OpSpecConstant %ulong 11\n",
+ },
+ // 16. Unchanged float default values
+ {
+ // code
+ "OpDecorate %1 SpecId 201\n"
+ "OpDecorate %2 SpecId 202\n"
+ "%float = OpTypeFloat 32\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %float 3.25\n"
+ "%2 = OpSpecConstant %double 1.25\n",
+ // default values
+ SpecIdToValueBitPatternMap{{201, {0x40500000}},
+ {202, {0x00000000, 0x3ff40000}}},
+ // expected
+ "OpDecorate %1 SpecId 201\n"
+ "OpDecorate %2 SpecId 202\n"
+ "%float = OpTypeFloat 32\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %float 3.25\n"
+ "%2 = OpSpecConstant %double 1.25\n",
+ },
+ // 17. OpGroupDecorate may have multiple target ids defined by the same
+ // eligible spec constant
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2 %2 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int 100\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0xffffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2 %2 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int -1\n",
+ },
+ // 18. For Boolean type spec constants,if any word in the bit pattern
+ // is not zero, it can be considered as a 'true', otherwise, it can be
+ // considered as a 'false'.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "OpDecorate %3 SpecId 102\n"
+ "%bool = OpTypeBool\n"
+ "%1 = OpSpecConstantTrue %bool\n"
+ "%2 = OpSpecConstantFalse %bool\n"
+ "%3 = OpSpecConstantFalse %bool\n",
+ // default values
+ SpecIdToValueBitPatternMap{
+ {100, {0x0, 0x0, 0x0, 0x0}},
+ {101, {0x10101010}},
+ {102, {0x0, 0x0, 0x0, 0x2}},
+ },
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "OpDecorate %3 SpecId 102\n"
+ "%bool = OpTypeBool\n"
+ "%1 = OpSpecConstantFalse %bool\n"
+ "%2 = OpSpecConstantTrue %bool\n"
+ "%3 = OpSpecConstantTrue %bool\n",
+ },
+ }));
+
+INSTANTIATE_TEST_CASE_P(
+ InvalidCases, SetSpecConstantDefaultValueInBitPatternFormParamTest,
+ ::testing::ValuesIn(std::vector<
+ SetSpecConstantDefaultValueInBitPatternFormTestCase>{
+ // 0. Do not crash when decoration group is not used.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "%int = OpTypeInt 32 1\n"
+ "%3 = OpSpecConstant %int 100\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0x7fffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "%int = OpTypeInt 32 1\n"
+ "%3 = OpSpecConstant %int 100\n",
+ },
+ // 1. Do not crash when target does not exist.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "%int = OpTypeInt 32 1\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0x7fffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "%int = OpTypeInt 32 1\n",
+ },
+ // 2. Do nothing when SpecId decoration is not attached to a
+ // non-spec-constant instruction.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "%int = OpTypeInt 32 1\n"
+ "%int_101 = OpConstant %int 101\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0x7fffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "%int = OpTypeInt 32 1\n"
+ "%int_101 = OpConstant %int 101\n",
+ },
+ // 3. Do nothing when SpecId decoration is not attached to a
+ // OpSpecConstant{|True|False} instruction.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "%int = OpTypeInt 32 1\n"
+ "%3 = OpSpecConstant %int 101\n"
+ "%1 = OpSpecConstantOp %int IAdd %3 %3\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0x7fffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "%int = OpTypeInt 32 1\n"
+ "%3 = OpSpecConstant %int 101\n"
+ "%1 = OpSpecConstantOp %int IAdd %3 %3\n",
+ },
+ // 4. Do not crash and do nothing when SpecId decoration is applied to
+ // multiple spec constants.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2 %3 %4\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int 100\n"
+ "%3 = OpSpecConstant %int 200\n"
+ "%4 = OpSpecConstant %int 300\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0xffffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2 %3 %4\n"
+ "%int = OpTypeInt 32 1\n"
+ "%2 = OpSpecConstant %int 100\n"
+ "%3 = OpSpecConstant %int 200\n"
+ "%4 = OpSpecConstant %int 300\n",
+ },
+ // 5. Do not crash and do nothing when SpecId decoration is attached to
+ // non-spec-constants (invalid case).
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%int_100 = OpConstant %int 100\n",
+ // default values
+ SpecIdToValueBitPatternMap{{100, {0xffffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "%1 = OpDecorationGroup\n"
+ "OpGroupDecorate %1 %2\n"
+ "%int = OpTypeInt 32 1\n"
+ "%int_100 = OpConstant %int 100\n",
+ },
+ // 6. Incompatible input bit pattern with the type. Nothing should be
+ // done in such a case.
+ {
+ // code
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "OpDecorate %3 SpecId 102\n"
+ "%int = OpTypeInt 32 1\n"
+ "%ulong = OpTypeInt 64 0\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %int 100\n"
+ "%2 = OpSpecConstant %ulong 200\n"
+ "%3 = OpSpecConstant %double 3.1415926\n",
+ // default values
+ SpecIdToValueBitPatternMap{
+ {100, {10, 0}}, {101, {11}}, {102, {0xffffffff}}},
+ // expected
+ "OpDecorate %1 SpecId 100\n"
+ "OpDecorate %2 SpecId 101\n"
+ "OpDecorate %3 SpecId 102\n"
+ "%int = OpTypeInt 32 1\n"
+ "%ulong = OpTypeInt 64 0\n"
+ "%double = OpTypeFloat 64\n"
+ "%1 = OpSpecConstant %int 100\n"
+ "%2 = OpSpecConstant %ulong 200\n"
+ "%3 = OpSpecConstant %double 3.1415926\n",
+ },
}));
} // anonymous namespace
diff --git a/test/preserve_numeric_ids_test.cpp b/test/preserve_numeric_ids_test.cpp
new file mode 100644
index 00000000..b8af9648
--- /dev/null
+++ b/test/preserve_numeric_ids_test.cpp
@@ -0,0 +1,158 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for unique type declaration rules validator.
+
+#include <string>
+
+#include "source/text.h"
+#include "source/text_handler.h"
+#include "test_fixture.h"
+
+namespace {
+
+using spvtest::ScopedContext;
+
+// Converts code to binary and then back to text.
+spv_result_t ToBinaryAndBack(
+ const std::string& before, std::string* after,
+ uint32_t text_to_binary_options = SPV_TEXT_TO_BINARY_OPTION_NONE,
+ uint32_t binary_to_text_options = SPV_BINARY_TO_TEXT_OPTION_NONE,
+ spv_target_env env = SPV_ENV_UNIVERSAL_1_0) {
+ ScopedContext ctx(env);
+ spv_binary binary;
+ spv_text text;
+
+ spv_result_t result = spvTextToBinaryWithOptions(
+ ctx.context, before.c_str(), before.size(), text_to_binary_options,
+ &binary, nullptr);
+ if (result != SPV_SUCCESS) {
+ return result;
+ }
+
+ result = spvBinaryToText(
+ ctx.context, binary->code, binary->wordCount, binary_to_text_options,
+ &text, nullptr);
+ if (result != SPV_SUCCESS) {
+ return result;
+ }
+
+ *after = std::string(text->str, text->length);
+
+ spvBinaryDestroy(binary);
+ spvTextDestroy(text);
+
+ return SPV_SUCCESS;
+}
+
+TEST(ToBinaryAndBack, DontPreserveNumericIds) {
+ const std::string before =
+R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%i32 = OpTypeInt 32 1
+%u32 = OpTypeInt 32 0
+%f32 = OpTypeFloat 32
+%200 = OpTypeVoid
+%300 = OpTypeFunction %200
+%main = OpFunction %200 None %300
+%entry = OpLabel
+%100 = OpConstant %u32 100
+%1 = OpConstant %u32 200
+%2 = OpConstant %u32 300
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string expected =
+R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%1 = OpTypeInt 32 1
+%2 = OpTypeInt 32 0
+%3 = OpTypeFloat 32
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpFunction %4 None %5
+%7 = OpLabel
+%8 = OpConstant %2 100
+%9 = OpConstant %2 200
+%10 = OpConstant %2 300
+OpReturn
+OpFunctionEnd
+)";
+
+ std::string after;
+ EXPECT_EQ(SPV_SUCCESS, ToBinaryAndBack(before, &after,
+ SPV_TEXT_TO_BINARY_OPTION_NONE,
+ SPV_BINARY_TO_TEXT_OPTION_NO_HEADER));
+
+ EXPECT_EQ(expected, after);
+}
+
+TEST(TextHandler, PreserveNumericIds) {
+ const std::string before =
+R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%i32 = OpTypeInt 32 1
+%u32 = OpTypeInt 32 0
+%f32 = OpTypeFloat 32
+%200 = OpTypeVoid
+%300 = OpTypeFunction %200
+%main = OpFunction %200 None %300
+%entry = OpLabel
+%100 = OpConstant %u32 100
+%1 = OpConstant %u32 200
+%2 = OpConstant %u32 300
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string expected =
+R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%3 = OpTypeInt 32 1
+%4 = OpTypeInt 32 0
+%5 = OpTypeFloat 32
+%200 = OpTypeVoid
+%300 = OpTypeFunction %200
+%6 = OpFunction %200 None %300
+%7 = OpLabel
+%100 = OpConstant %4 100
+%1 = OpConstant %4 200
+%2 = OpConstant %4 300
+OpReturn
+OpFunctionEnd
+)";
+
+ std::string after;
+ EXPECT_EQ(SPV_SUCCESS,
+ ToBinaryAndBack(before, &after,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS,
+ SPV_BINARY_TO_TEXT_OPTION_NO_HEADER));
+
+ EXPECT_EQ(expected, after);
+}
+
+} // namespace
diff --git a/test/scripts/test_compact_ids.py b/test/scripts/test_compact_ids.py
new file mode 100644
index 00000000..b9b5b1bc
--- /dev/null
+++ b/test/scripts/test_compact_ids.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+# Copyright (c) 2017 Google Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests correctness of opt pass tools/opt --compact-ids."""
+
+from __future__ import print_function
+
+import os.path
+import sys
+import tempfile
+
+def test_spirv_file(path, temp_dir):
+ optimized_spv_path = os.path.join(temp_dir, 'optimized.spv')
+ optimized_dis_path = os.path.join(temp_dir, 'optimized.dis')
+ converted_spv_path = os.path.join(temp_dir, 'converted.spv')
+ converted_dis_path = os.path.join(temp_dir, 'converted.dis')
+
+ os.system('tools/spirv-opt ' + path + ' -o ' + optimized_spv_path +
+ ' --compact-ids')
+ os.system('tools/spirv-dis ' + optimized_spv_path + ' -o ' +
+ optimized_dis_path)
+
+ os.system('tools/spirv-dis ' + path + ' -o ' + converted_dis_path)
+ os.system('tools/spirv-as ' + converted_dis_path + ' -o ' +
+ converted_spv_path)
+ os.system('tools/spirv-dis ' + converted_spv_path + ' -o ' +
+ converted_dis_path)
+
+ with open(converted_dis_path, 'r') as f:
+ converted_dis = f.readlines()[3:]
+
+ with open(optimized_dis_path, 'r') as f:
+ optimized_dis = f.readlines()[3:]
+
+ return converted_dis == optimized_dis
+
+def print_usage():
+ template= \
+"""{script} tests correctness of opt pass tools/opt --compact-ids
+
+USAGE: python {script} [<spirv_files>]
+
+Requires tools/spirv-dis, tools/spirv-as and tools/spirv-opt to be in path
+(call the script from the SPIRV-Tools build output directory).
+
+TIP: In order to test all .spv files under current dir use
+find <path> -name "*.spv" -print0 | xargs -0 -s 2000000 python {script}
+"""
+ print(template.format(script=sys.argv[0]));
+
+def main():
+ if not os.path.isfile('tools/spirv-dis'):
+ print('error: tools/spirv-dis not found')
+ print_usage()
+ exit(1)
+
+ if not os.path.isfile('tools/spirv-as'):
+ print('error: tools/spirv-as not found')
+ print_usage()
+ exit(1)
+
+ if not os.path.isfile('tools/spirv-opt'):
+ print('error: tools/spirv-opt not found')
+ print_usage()
+ exit(1)
+
+ paths = sys.argv[1:]
+ if not paths:
+ print_usage()
+
+ num_failed = 0
+
+ temp_dir = tempfile.mkdtemp()
+
+ for path in paths:
+ success = test_spirv_file(path, temp_dir)
+ if not success:
+ print('Test failed for ' + path)
+ num_failed += 1
+
+ print('Tested ' + str(len(paths)) + ' files')
+
+ if num_failed:
+ print(str(num_failed) + ' tests failed')
+ exit(1)
+ else:
+ print('All tests successful')
+ exit(0)
+
+if __name__ == '__main__':
+ main()
diff --git a/test/stats/CMakeLists.txt b/test/stats/CMakeLists.txt
new file mode 100644
index 00000000..20f05fd0
--- /dev/null
+++ b/test/stats/CMakeLists.txt
@@ -0,0 +1,31 @@
+# Copyright (c) 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(VAL_TEST_COMMON_SRCS
+ ${CMAKE_CURRENT_SOURCE_DIR}/../test_fixture.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/../unit_spirv.h
+)
+
+add_spvtools_unittest(TARGET stats_aggregate
+ SRCS stats_aggregate_test.cpp
+ ${VAL_TEST_COMMON_SRCS}
+ LIBS ${SPIRV_TOOLS}
+)
+
+add_spvtools_unittest(TARGET stats_analyzer
+ SRCS stats_analyzer_test.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../tools/stats/stats_analyzer.cpp
+ ${VAL_TEST_COMMON_SRCS}
+ LIBS ${SPIRV_TOOLS}
+)
diff --git a/test/stats/stats_aggregate_test.cpp b/test/stats/stats_aggregate_test.cpp
new file mode 100644
index 00000000..43026e80
--- /dev/null
+++ b/test/stats/stats_aggregate_test.cpp
@@ -0,0 +1,436 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for unique type declaration rules validator.
+
+#include <string>
+
+#include "source/spirv_stats.h"
+#include "test_fixture.h"
+#include "unit_spirv.h"
+
+namespace {
+
+using libspirv::SpirvStats;
+using spvtest::ScopedContext;
+
+void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
+ const spv_position_t& position,
+ const char* message) {
+ switch (level) {
+ case SPV_MSG_FATAL:
+ case SPV_MSG_INTERNAL_ERROR:
+ case SPV_MSG_ERROR:
+ std::cerr << "error: " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_WARNING:
+ std::cout << "warning: " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_INFO:
+ std::cout << "info: " << position.index << ": " << message << std::endl;
+ break;
+ default:
+ break;
+ }
+}
+
+// Calls libspirv::AggregateStats for binary compiled from |code|.
+void CompileAndAggregateStats(const std::string& code, SpirvStats* stats,
+ spv_target_env env = SPV_ENV_UNIVERSAL_1_1) {
+ ScopedContext ctx(env);
+ SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler);
+ spv_binary binary;
+ ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(
+ ctx.context, code.c_str(), code.size(), &binary, nullptr));
+
+ ASSERT_EQ(SPV_SUCCESS, AggregateStats(*ctx.context, binary->code,
+ binary->wordCount, nullptr, stats));
+ spvBinaryDestroy(binary);
+}
+
+TEST(AggregateStats, CapabilityHistogram) {
+ const std::string code1 = R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+)";
+
+ const std::string code2 = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+)";
+
+ SpirvStats stats;
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(4u, stats.capability_hist.size());
+ EXPECT_EQ(0u, stats.capability_hist.count(SpvCapabilityShader));
+ EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityAddresses));
+ EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityKernel));
+ EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityGenericPointer));
+ EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityLinkage));
+
+ CompileAndAggregateStats(code2, &stats);
+ EXPECT_EQ(5u, stats.capability_hist.size());
+ EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityShader));
+ EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityAddresses));
+ EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityKernel));
+ EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityGenericPointer));
+ EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityLinkage));
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(5u, stats.capability_hist.size());
+ EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityShader));
+ EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityAddresses));
+ EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityKernel));
+ EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityGenericPointer));
+ EXPECT_EQ(3u, stats.capability_hist.at(SpvCapabilityLinkage));
+
+ CompileAndAggregateStats(code2, &stats);
+ EXPECT_EQ(5u, stats.capability_hist.size());
+ EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityShader));
+ EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityAddresses));
+ EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityKernel));
+ EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityGenericPointer));
+ EXPECT_EQ(4u, stats.capability_hist.at(SpvCapabilityLinkage));
+}
+
+TEST(AggregateStats, ExtensionHistogram) {
+ const std::string code1 = R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpExtension "SPV_KHR_16bit_storage"
+OpMemoryModel Physical32 OpenCL
+)";
+
+ const std::string code2 = R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_NV_viewport_array2"
+OpExtension "greatest_extension_ever"
+OpMemoryModel Logical GLSL450
+)";
+
+ SpirvStats stats;
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(1u, stats.extension_hist.size());
+ EXPECT_EQ(0u, stats.extension_hist.count("SPV_NV_viewport_array2"));
+ EXPECT_EQ(1u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
+
+ CompileAndAggregateStats(code2, &stats);
+ EXPECT_EQ(3u, stats.extension_hist.size());
+ EXPECT_EQ(1u, stats.extension_hist.at("SPV_NV_viewport_array2"));
+ EXPECT_EQ(1u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
+ EXPECT_EQ(1u, stats.extension_hist.at("greatest_extension_ever"));
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(3u, stats.extension_hist.size());
+ EXPECT_EQ(1u, stats.extension_hist.at("SPV_NV_viewport_array2"));
+ EXPECT_EQ(2u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
+ EXPECT_EQ(1u, stats.extension_hist.at("greatest_extension_ever"));
+
+ CompileAndAggregateStats(code2, &stats);
+ EXPECT_EQ(3u, stats.extension_hist.size());
+ EXPECT_EQ(2u, stats.extension_hist.at("SPV_NV_viewport_array2"));
+ EXPECT_EQ(2u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
+ EXPECT_EQ(2u, stats.extension_hist.at("greatest_extension_ever"));
+}
+
+TEST(AggregateStats, VersionHistogram) {
+ const std::string code1 = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+)";
+
+ SpirvStats stats;
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(1u, stats.version_hist.size());
+ EXPECT_EQ(1u, stats.version_hist.at(0x00010100));
+
+ CompileAndAggregateStats(code1, &stats, SPV_ENV_UNIVERSAL_1_0);
+ EXPECT_EQ(2u, stats.version_hist.size());
+ EXPECT_EQ(1u, stats.version_hist.at(0x00010100));
+ EXPECT_EQ(1u, stats.version_hist.at(0x00010000));
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(2u, stats.version_hist.size());
+ EXPECT_EQ(2u, stats.version_hist.at(0x00010100));
+ EXPECT_EQ(1u, stats.version_hist.at(0x00010000));
+
+ CompileAndAggregateStats(code1, &stats, SPV_ENV_UNIVERSAL_1_0);
+ EXPECT_EQ(2u, stats.version_hist.size());
+ EXPECT_EQ(2u, stats.version_hist.at(0x00010100));
+ EXPECT_EQ(2u, stats.version_hist.at(0x00010000));
+}
+
+TEST(AggregateStats, GeneratorHistogram) {
+ const std::string code1 = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+)";
+
+ const uint32_t kGeneratorKhronosAssembler =
+ SPV_GENERATOR_KHRONOS_ASSEMBLER << 16;
+
+ SpirvStats stats;
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(1u, stats.generator_hist.size());
+ EXPECT_EQ(1u, stats.generator_hist.at(kGeneratorKhronosAssembler));
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(1u, stats.generator_hist.size());
+ EXPECT_EQ(2u, stats.generator_hist.at(kGeneratorKhronosAssembler));
+}
+
+TEST(AggregateStats, OpcodeHistogram) {
+ const std::string code1 = R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int64
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%u64 = OpTypeInt 64 0
+%u32 = OpTypeInt 32 0
+%f32 = OpTypeFloat 32
+)";
+
+ const std::string code2 = R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_NV_viewport_array2"
+OpMemoryModel Logical GLSL450
+)";
+
+ SpirvStats stats;
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(4u, stats.opcode_hist.size());
+ EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpCapability));
+ EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpMemoryModel));
+ EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeInt));
+ EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpTypeFloat));
+
+ CompileAndAggregateStats(code2, &stats);
+ EXPECT_EQ(5u, stats.opcode_hist.size());
+ EXPECT_EQ(6u, stats.opcode_hist.at(SpvOpCapability));
+ EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpMemoryModel));
+ EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeInt));
+ EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpTypeFloat));
+ EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpExtension));
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(5u, stats.opcode_hist.size());
+ EXPECT_EQ(10u, stats.opcode_hist.at(SpvOpCapability));
+ EXPECT_EQ(3u, stats.opcode_hist.at(SpvOpMemoryModel));
+ EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpTypeInt));
+ EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeFloat));
+ EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpExtension));
+
+ CompileAndAggregateStats(code2, &stats);
+ EXPECT_EQ(5u, stats.opcode_hist.size());
+ EXPECT_EQ(12u, stats.opcode_hist.at(SpvOpCapability));
+ EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpMemoryModel));
+ EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpTypeInt));
+ EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeFloat));
+ EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpExtension));
+}
+
+TEST(AggregateStats, OpcodeMarkovHistogram) {
+ const std::string code1 = R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_NV_viewport_array2"
+OpMemoryModel Logical GLSL450
+)";
+
+ const std::string code2 = R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int64
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%u64 = OpTypeInt 64 0
+%u32 = OpTypeInt 32 0
+%f32 = OpTypeFloat 32
+)";
+
+ SpirvStats stats;
+ stats.opcode_markov_hist.resize(2);
+
+ CompileAndAggregateStats(code1, &stats);
+ ASSERT_EQ(2u, stats.opcode_markov_hist.size());
+ EXPECT_EQ(2u, stats.opcode_markov_hist[0].size());
+ EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpCapability).size());
+ EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size());
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel));
+
+ EXPECT_EQ(1u, stats.opcode_markov_hist[1].size());
+ EXPECT_EQ(2u, stats.opcode_markov_hist[1].at(SpvOpCapability).size());
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel));
+
+ CompileAndAggregateStats(code2, &stats);
+ ASSERT_EQ(2u, stats.opcode_markov_hist.size());
+ EXPECT_EQ(4u, stats.opcode_markov_hist[0].size());
+ EXPECT_EQ(3u, stats.opcode_markov_hist[0].at(SpvOpCapability).size());
+ EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size());
+ EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpMemoryModel).size());
+ EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).size());
+ EXPECT_EQ(
+ 4u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpMemoryModel));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[0].at(SpvOpMemoryModel).at(SpvOpTypeInt));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeInt));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeFloat));
+
+ EXPECT_EQ(3u, stats.opcode_markov_hist[1].size());
+ EXPECT_EQ(4u, stats.opcode_markov_hist[1].at(SpvOpCapability).size());
+ EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpMemoryModel).size());
+ EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpTypeInt).size());
+ EXPECT_EQ(
+ 2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpCapability));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension));
+ EXPECT_EQ(
+ 2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpTypeInt));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[1].at(SpvOpMemoryModel).at(SpvOpTypeInt));
+ EXPECT_EQ(
+ 1u, stats.opcode_markov_hist[1].at(SpvOpTypeInt).at(SpvOpTypeFloat));
+}
+
+TEST(AggregateStats, ConstantLiteralsHistogram) {
+ const std::string code1 = R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Float64
+OpCapability Int16
+OpCapability Int64
+OpMemoryModel Physical32 OpenCL
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%u64 = OpTypeInt 64 0
+%f32 = OpTypeFloat 32
+%f64 = OpTypeFloat 64
+%1 = OpConstant %f32 0.1
+%2 = OpConstant %f32 -2
+%3 = OpConstant %f64 -2
+%4 = OpConstant %u16 16
+%5 = OpConstant %u16 2
+%6 = OpConstant %u32 32
+%7 = OpConstant %u64 64
+)";
+
+ const std::string code2 = R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Int16
+OpCapability Int64
+OpMemoryModel Logical GLSL450
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%s16 = OpTypeInt 16 1
+%u32 = OpTypeInt 32 0
+%s32 = OpTypeInt 32 1
+%u64 = OpTypeInt 64 0
+%s64 = OpTypeInt 64 1
+%1 = OpConstant %f32 0.1
+%2 = OpConstant %f32 -2
+%3 = OpConstant %u16 1
+%4 = OpConstant %u16 16
+%5 = OpConstant %u16 2
+%6 = OpConstant %s16 -16
+%7 = OpConstant %u32 32
+%8 = OpConstant %s32 2
+%9 = OpConstant %s32 -32
+%10 = OpConstant %u64 64
+%11 = OpConstant %s64 -64
+)";
+
+ SpirvStats stats;
+
+ CompileAndAggregateStats(code1, &stats);
+ EXPECT_EQ(2u, stats.f32_constant_hist.size());
+ EXPECT_EQ(1u, stats.f64_constant_hist.size());
+ EXPECT_EQ(1u, stats.f32_constant_hist.at(0.1f));
+ EXPECT_EQ(1u, stats.f32_constant_hist.at(-2.f));
+ EXPECT_EQ(1u, stats.f64_constant_hist.at(-2));
+
+ EXPECT_EQ(2u, stats.u16_constant_hist.size());
+ EXPECT_EQ(0u, stats.s16_constant_hist.size());
+ EXPECT_EQ(1u, stats.u32_constant_hist.size());
+ EXPECT_EQ(0u, stats.s32_constant_hist.size());
+ EXPECT_EQ(1u, stats.u64_constant_hist.size());
+ EXPECT_EQ(0u, stats.s64_constant_hist.size());
+ EXPECT_EQ(1u, stats.u16_constant_hist.at(16));
+ EXPECT_EQ(1u, stats.u16_constant_hist.at(2));
+ EXPECT_EQ(1u, stats.u32_constant_hist.at(32));
+ EXPECT_EQ(1u, stats.u64_constant_hist.at(64));
+
+ CompileAndAggregateStats(code2, &stats);
+ EXPECT_EQ(2u, stats.f32_constant_hist.size());
+ EXPECT_EQ(1u, stats.f64_constant_hist.size());
+ EXPECT_EQ(2u, stats.f32_constant_hist.at(0.1f));
+ EXPECT_EQ(2u, stats.f32_constant_hist.at(-2.f));
+ EXPECT_EQ(1u, stats.f64_constant_hist.at(-2));
+
+ EXPECT_EQ(3u, stats.u16_constant_hist.size());
+ EXPECT_EQ(1u, stats.s16_constant_hist.size());
+ EXPECT_EQ(1u, stats.u32_constant_hist.size());
+ EXPECT_EQ(2u, stats.s32_constant_hist.size());
+ EXPECT_EQ(1u, stats.u64_constant_hist.size());
+ EXPECT_EQ(1u, stats.s64_constant_hist.size());
+ EXPECT_EQ(2u, stats.u16_constant_hist.at(16));
+ EXPECT_EQ(2u, stats.u16_constant_hist.at(2));
+ EXPECT_EQ(1u, stats.u16_constant_hist.at(1));
+ EXPECT_EQ(1u, stats.s16_constant_hist.at(-16));
+ EXPECT_EQ(2u, stats.u32_constant_hist.at(32));
+ EXPECT_EQ(1u, stats.s32_constant_hist.at(2));
+ EXPECT_EQ(1u, stats.s32_constant_hist.at(-32));
+ EXPECT_EQ(2u, stats.u64_constant_hist.at(64));
+ EXPECT_EQ(1u, stats.s64_constant_hist.at(-64));
+}
+
+} // namespace
diff --git a/test/stats/stats_analyzer_test.cpp b/test/stats/stats_analyzer_test.cpp
new file mode 100644
index 00000000..6dcaac4f
--- /dev/null
+++ b/test/stats/stats_analyzer_test.cpp
@@ -0,0 +1,172 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for unique type declaration rules validator.
+
+#include <string>
+#include <sstream>
+
+#include "spirv/1.1/spirv.h"
+#include "test_fixture.h"
+#include "tools/stats/stats_analyzer.h"
+
+namespace {
+
+using libspirv::SpirvStats;
+
+// Fills |stats| with some synthetic header stats, as if aggregated from 100
+// modules (100 used for simpler percentage evaluation).
+void FillDefaultStats(SpirvStats* stats) {
+ *stats = SpirvStats();
+ stats->version_hist[0x00010000] = 40;
+ stats->version_hist[0x00010100] = 60;
+ stats->generator_hist[0x00000000] = 64;
+ stats->generator_hist[0x00010000] = 1;
+ stats->generator_hist[0x00020000] = 2;
+ stats->generator_hist[0x00030000] = 3;
+ stats->generator_hist[0x00040000] = 4;
+ stats->generator_hist[0x00050000] = 5;
+ stats->generator_hist[0x00060000] = 6;
+ stats->generator_hist[0x00070000] = 7;
+ stats->generator_hist[0x00080000] = 8;
+
+ int num_version_entries = 0;
+ for (const auto& pair : stats->version_hist) {
+ num_version_entries += pair.second;
+ }
+
+ int num_generator_entries = 0;
+ for (const auto& pair : stats->generator_hist) {
+ num_generator_entries += pair.second;
+ }
+
+ EXPECT_EQ(num_version_entries, num_generator_entries);
+}
+
+TEST(StatsAnalyzer, Version) {
+ SpirvStats stats;
+ FillDefaultStats(&stats);
+
+ StatsAnalyzer analyzer(stats);
+
+ std::stringstream ss;
+ analyzer.WriteVersion(ss);
+ const std::string output = ss.str();
+ const std::string expected_output = "Version 1.1 60%\nVersion 1.0 40%\n";
+
+ EXPECT_EQ(expected_output, output);
+}
+
+TEST(StatsAnalyzer, Generator) {
+ SpirvStats stats;
+ FillDefaultStats(&stats);
+
+ StatsAnalyzer analyzer(stats);
+
+ std::stringstream ss;
+ analyzer.WriteGenerator(ss);
+ const std::string output = ss.str();
+ const std::string expected_output =
+ "Khronos 64%\nKhronos Glslang Reference Front End 8%\n"
+ "Khronos SPIR-V Tools Assembler 7%\nKhronos LLVM/SPIR-V Translator 6%"
+ "\nARM 5%\nNVIDIA 4%\nCodeplay 3%\nValve 2%\nLunarG 1%\n";
+
+ EXPECT_EQ(expected_output, output);
+}
+
+TEST(StatsAnalyzer, Capability) {
+ SpirvStats stats;
+ FillDefaultStats(&stats);
+
+ stats.capability_hist[SpvCapabilityShader] = 25;
+ stats.capability_hist[SpvCapabilityKernel] = 75;
+
+ StatsAnalyzer analyzer(stats);
+
+ std::stringstream ss;
+ analyzer.WriteCapability(ss);
+ const std::string output = ss.str();
+ const std::string expected_output = "Kernel 75%\nShader 25%\n";
+
+ EXPECT_EQ(expected_output, output);
+}
+
+TEST(StatsAnalyzer, Extension) {
+ SpirvStats stats;
+ FillDefaultStats(&stats);
+
+ stats.extension_hist["greatest_extension_ever"] = 1;
+ stats.extension_hist["worst_extension_ever"] = 10;
+
+ StatsAnalyzer analyzer(stats);
+
+ std::stringstream ss;
+ analyzer.WriteExtension(ss);
+ const std::string output = ss.str();
+ const std::string expected_output =
+ "worst_extension_ever 10%\ngreatest_extension_ever 1%\n";
+
+ EXPECT_EQ(expected_output, output);
+}
+
+TEST(StatsAnalyzer, Opcode) {
+ SpirvStats stats;
+ FillDefaultStats(&stats);
+
+ stats.opcode_hist[SpvOpCapability] = 20;
+ stats.opcode_hist[SpvOpConstant] = 80;
+ stats.opcode_hist[SpvOpDecorate] = 100;
+
+ StatsAnalyzer analyzer(stats);
+
+ std::stringstream ss;
+ analyzer.WriteOpcode(ss);
+ const std::string output = ss.str();
+ const std::string expected_output =
+ "Total unique opcodes used: 3\nDecorate 50%\n"
+ "Constant 40%\nCapability 10%\n";
+
+ EXPECT_EQ(expected_output, output);
+}
+
+TEST(StatsAnalyzer, OpcodeMarkov) {
+ SpirvStats stats;
+ FillDefaultStats(&stats);
+
+ stats.opcode_hist[SpvOpFMul] = 400;
+ stats.opcode_hist[SpvOpFAdd] = 200;
+ stats.opcode_hist[SpvOpFSub] = 400;
+
+ stats.opcode_markov_hist.resize(1);
+ auto& hist = stats.opcode_markov_hist[0];
+ hist[SpvOpFMul][SpvOpFAdd] = 100;
+ hist[SpvOpFMul][SpvOpFSub] = 300;
+ hist[SpvOpFAdd][SpvOpFMul] = 100;
+ hist[SpvOpFAdd][SpvOpFAdd] = 100;
+
+ StatsAnalyzer analyzer(stats);
+
+ std::stringstream ss;
+ analyzer.WriteOpcodeMarkov(ss);
+ const std::string output = ss.str();
+ const std::string expected_output =
+ "FMul -> FSub 75% (base rate 40%, pair occurrences 300)\n"
+ "FMul -> FAdd 25% (base rate 20%, pair occurrences 100)\n"
+ "FAdd -> FAdd 50% (base rate 20%, pair occurrences 100)\n"
+ "FAdd -> FMul 50% (base rate 40%, pair occurrences 100)\n";
+
+ EXPECT_EQ(expected_output, output);
+}
+
+} // namespace
diff --git a/test/target_env_test.cpp b/test/target_env_test.cpp
index 4df5f943..0c8389da 100644
--- a/test/target_env_test.cpp
+++ b/test/target_env_test.cpp
@@ -41,14 +41,15 @@ TEST_P(TargetEnvTest, ValidDescription) {
TEST_P(TargetEnvTest, ValidSpirvVersion) {
auto spirv_version = spvVersionForTargetEnv(GetParam());
- ASSERT_THAT(spirv_version, AnyOf(0x10000, 0x10100));
+ ASSERT_THAT(spirv_version, AnyOf(0x10000, 0x10100, 0x10200));
}
INSTANTIATE_TEST_CASE_P(AllTargetEnvs, TargetEnvTest,
ValuesIn(spvtest::AllTargetEnvironments()));
TEST(GetContextTest, InvalidTargetEnvProducesNull) {
- spv_context context = spvContextCreate((spv_target_env)10);
+ // Use a value beyond the last valid enum value.
+ spv_context context = spvContextCreate(static_cast<spv_target_env>(15));
EXPECT_EQ(context, nullptr);
}
@@ -72,6 +73,7 @@ INSTANTIATE_TEST_CASE_P(TargetParsing, TargetParseTest,
ValuesIn(std::vector<ParseCase>{
{"spv1.0", true, SPV_ENV_UNIVERSAL_1_0},
{"spv1.1", true, SPV_ENV_UNIVERSAL_1_1},
+ {"spv1.2", true, SPV_ENV_UNIVERSAL_1_2},
{"vulkan1.0", true, SPV_ENV_VULKAN_1_0},
{"opencl2.1", true, SPV_ENV_OPENCL_2_1},
{"opencl2.2", true, SPV_ENV_OPENCL_2_2},
diff --git a/test/text_to_binary.debug_test.cpp b/test/text_to_binary.debug_test.cpp
index cefaaec4..ec2acab2 100644
--- a/test/text_to_binary.debug_test.cpp
+++ b/test/text_to_binary.debug_test.cpp
@@ -54,6 +54,8 @@ const LanguageCase kLanguageCases[] = {
CASE(OpenCL_C, 200),
CASE(OpenCL_C, 210),
CASE(OpenCL_CPP, 210),
+ CASE(HLSL, 5),
+ CASE(HLSL, 6),
#undef CASE
};
// clang-format on
diff --git a/test/text_to_binary.extension_test.cpp b/test/text_to_binary.extension_test.cpp
index 7614034c..45d1a23c 100644
--- a/test/text_to_binary.extension_test.cpp
+++ b/test/text_to_binary.extension_test.cpp
@@ -96,6 +96,17 @@ struct AssemblyCase {
std::vector<uint32_t> expected;
};
+using ExtensionAssemblyTest = spvtest::TextToBinaryTestBase<
+ ::testing::TestWithParam<std::tuple<spv_target_env, AssemblyCase>>>;
+
+TEST_P(ExtensionAssemblyTest, Samples) {
+ const spv_target_env& env = std::get<0>(GetParam());
+ const AssemblyCase& ac = std::get<1>(GetParam());
+
+ // Check that it assembles correctly.
+ EXPECT_THAT(CompiledInstructions(ac.input, env), Eq(ac.expected));
+}
+
using ExtensionRoundTripTest = spvtest::TextToBinaryTestBase<
::testing::TestWithParam<std::tuple<spv_target_env, AssemblyCase>>>;
@@ -197,10 +208,17 @@ INSTANTIATE_TEST_CASE_P(
Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
SPV_ENV_VULKAN_1_0),
ValuesIn(std::vector<AssemblyCase>{
- {"OpCapability StorageUniformBufferBlock16\n",
+ {"OpCapability StorageBuffer16BitAccess\n",
MakeInstruction(SpvOpCapability,
{SpvCapabilityStorageUniformBufferBlock16})},
- {"OpCapability StorageUniform16\n",
+ {"OpCapability StorageBuffer16BitAccess\n",
+ MakeInstruction(SpvOpCapability,
+ {SpvCapabilityStorageBuffer16BitAccess})},
+ {"OpCapability UniformAndStorageBuffer16BitAccess\n",
+ MakeInstruction(
+ SpvOpCapability,
+ {SpvCapabilityUniformAndStorageBuffer16BitAccess})},
+ {"OpCapability UniformAndStorageBuffer16BitAccess\n",
MakeInstruction(SpvOpCapability,
{SpvCapabilityStorageUniform16})},
{"OpCapability StoragePushConstant16\n",
@@ -211,6 +229,22 @@ INSTANTIATE_TEST_CASE_P(
{SpvCapabilityStorageInputOutput16})},
})), );
+INSTANTIATE_TEST_CASE_P(
+ SPV_KHR_16bit_storage_alias_check, ExtensionAssemblyTest,
+ Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_VULKAN_1_0),
+ ValuesIn(std::vector<AssemblyCase>{
+ // The old name maps to the new enum.
+ {"OpCapability StorageUniformBufferBlock16\n",
+ MakeInstruction(SpvOpCapability,
+ {SpvCapabilityStorageBuffer16BitAccess})},
+ // The old name maps to the new enum.
+ {"OpCapability StorageUniform16\n",
+ MakeInstruction(
+ SpvOpCapability,
+ {SpvCapabilityUniformAndStorageBuffer16BitAccess})},
+ })), );
+
// SPV_KHR_device_group
INSTANTIATE_TEST_CASE_P(
@@ -243,6 +277,7 @@ INSTANTIATE_TEST_CASE_P(
SpvBuiltInViewIndex})},
})), );
+
// SPV_AMD_gcn_shader
#define PREAMBLE "%1 = OpExtInstImport \"SPV_AMD_gcn_shader\"\n"
@@ -268,4 +303,23 @@ INSTANTIATE_TEST_CASE_P(
})), );
#undef PREAMBLE
+
+// SPV_KHR_variable_pointers
+
+INSTANTIATE_TEST_CASE_P(
+ SPV_KHR_variable_pointers, ExtensionRoundTripTest,
+ // We'll get coverage over operand tables by trying the universal
+ // environments, and at least one specific environment.
+ Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_VULKAN_1_0),
+ ValuesIn(std::vector<AssemblyCase>{
+ {"OpCapability VariablePointers\n",
+ MakeInstruction(SpvOpCapability,
+ {SpvCapabilityVariablePointers})},
+ {"OpCapability VariablePointersStorageBuffer\n",
+ MakeInstruction(
+ SpvOpCapability,
+ {SpvCapabilityVariablePointersStorageBuffer})},
+ })), );
+
} // anonymous namespace
diff --git a/test/text_to_binary.type_declaration_test.cpp b/test/text_to_binary.type_declaration_test.cpp
index cc2b1656..20f3797e 100644
--- a/test/text_to_binary.type_declaration_test.cpp
+++ b/test/text_to_binary.type_declaration_test.cpp
@@ -213,6 +213,7 @@ TEST_F(OpTypeForwardPointerTest, ValidStorageClass) {
CASE(PushConstant);
CASE(AtomicCounter);
CASE(Image);
+ CASE(StorageBuffer);
}
#undef CASE
diff --git a/test/unit_spirv.h b/test/unit_spirv.h
index 823fd78e..47448158 100644
--- a/test/unit_spirv.h
+++ b/test/unit_spirv.h
@@ -211,7 +211,7 @@ inline std::vector<spv_target_env> AllTargetEnvironments() {
return {SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENCL_2_1,
SPV_ENV_OPENCL_2_2, SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_0,
SPV_ENV_OPENGL_4_1, SPV_ENV_OPENGL_4_2, SPV_ENV_OPENGL_4_3,
- SPV_ENV_OPENGL_4_5};
+ SPV_ENV_OPENGL_4_5, SPV_ENV_UNIVERSAL_1_2};
}
// Returns the capabilities in a CapabilitySet as an ordered vector.
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 754420a1..e1bd5775 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -1805,6 +1805,163 @@ TEST_F(ValidateIdWithMessage, OpLoadGood) {
CompileSuccessfully(spirv.c_str());
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+
+// TODO: Add tests that exercise VariablePointersStorageBuffer instead of
+// VariablePointers.
+void createVariablePointerSpirvProgram(std::ostringstream* spirv,
+ std::string result_strategy,
+ bool use_varptr_cap,
+ bool add_helper_function) {
+ *spirv << "OpCapability Shader ";
+ if (use_varptr_cap) {
+ *spirv << "OpCapability VariablePointers ";
+ *spirv << "OpExtension \"SPV_KHR_variable_pointers\" ";
+ }
+ *spirv << R"(
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ %void = OpTypeVoid
+ %voidf = OpTypeFunction %void
+ %bool = OpTypeBool
+ %i32 = OpTypeInt 32 1
+ %f32 = OpTypeFloat 32
+ %f32ptr = OpTypePointer Uniform %f32
+ %i = OpConstant %i32 1
+ %zero = OpConstant %i32 0
+ %float_1 = OpConstant %f32 1.0
+ %ptr1 = OpVariable %f32ptr Uniform
+ %ptr2 = OpVariable %f32ptr Uniform
+ )";
+ if (add_helper_function) {
+ *spirv << R"(
+ ; ////////////////////////////////////////////////////////////
+ ;;;; Function that returns a pointer
+ ; ////////////////////////////////////////////////////////////
+ %selector_func_type = OpTypeFunction %f32ptr %bool %f32ptr %f32ptr
+ %choose_input_func = OpFunction %f32ptr None %selector_func_type
+ %is_neg_param = OpFunctionParameter %bool
+ %first_ptr_param = OpFunctionParameter %f32ptr
+ %second_ptr_param = OpFunctionParameter %f32ptr
+ %selector_func_begin = OpLabel
+ %result_ptr = OpSelect %f32ptr %is_neg_param %first_ptr_param %second_ptr_param
+ OpReturnValue %result_ptr
+ OpFunctionEnd
+ )";
+ }
+ *spirv << R"(
+ %main = OpFunction %void None %voidf
+ %label = OpLabel
+ )";
+ *spirv << result_strategy;
+ *spirv << R"(
+ OpReturn
+ OpFunctionEnd
+ )";
+}
+
+// With the VariablePointer Capability, OpLoad should allow loading a
+// VaiablePointer. In this test the variable pointer is obtained by an OpSelect
+TEST_F(ValidateIdWithMessage, OpLoadVarPtrOpSelectGood) {
+ std::string result_strategy = R"(
+ %isneg = OpSLessThan %bool %i %zero
+ %varptr = OpSelect %f32ptr %isneg %ptr1 %ptr2
+ %result = OpLoad %f32 %varptr
+ )";
+
+ std::ostringstream spirv;
+ createVariablePointerSpirvProgram(&spirv, result_strategy,
+ true /* Add VariablePointers Capability? */,
+ false /* Use Helper Function? */);
+ CompileSuccessfully(spirv.str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Without the VariablePointers Capability, OpLoad will not allow loading
+// through a variable pointer.
+TEST_F(ValidateIdWithMessage, OpLoadVarPtrOpSelectBad) {
+ std::string result_strategy = R"(
+ %isneg = OpSLessThan %bool %i %zero
+ %varptr = OpSelect %f32ptr %isneg %ptr1 %ptr2
+ %result = OpLoad %f32 %varptr
+ )";
+
+ std::ostringstream spirv;
+ createVariablePointerSpirvProgram(&spirv, result_strategy,
+ false /* Add VariablePointers Capability?*/,
+ false /* Use Helper Function? */);
+ CompileSuccessfully(spirv.str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("is not a logical pointer."));
+}
+
+// With the VariablePointer Capability, OpLoad should allow loading a
+// VaiablePointer. In this test the variable pointer is obtained by an OpPhi
+TEST_F(ValidateIdWithMessage, OpLoadVarPtrOpPhiGood) {
+ std::string result_strategy = R"(
+ %is_neg = OpSLessThan %bool %i %zero
+ OpSelectionMerge %end_label None
+ OpBranchConditional %is_neg %take_ptr_1 %take_ptr_2
+ %take_ptr_1 = OpLabel
+ OpBranch %end_label
+ %take_ptr_2 = OpLabel
+ OpBranch %end_label
+ %end_label = OpLabel
+ %varptr = OpPhi %f32ptr %ptr1 %take_ptr_1 %ptr2 %take_ptr_2
+ %result = OpLoad %f32 %varptr
+ )";
+
+ std::ostringstream spirv;
+ createVariablePointerSpirvProgram(&spirv, result_strategy,
+ true /* Add VariablePointers Capability?*/,
+ false /* Use Helper Function? */);
+ CompileSuccessfully(spirv.str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Without the VariablePointers Capability, OpLoad will not allow loading
+// through a variable pointer.
+TEST_F(ValidateIdWithMessage, OpLoadVarPtrOpPhiBad) {
+ std::string result_strategy = R"(
+ %is_neg = OpSLessThan %bool %i %zero
+ OpSelectionMerge %end_label None
+ OpBranchConditional %is_neg %take_ptr_1 %take_ptr_2
+ %take_ptr_1 = OpLabel
+ OpBranch %end_label
+ %take_ptr_2 = OpLabel
+ OpBranch %end_label
+ %end_label = OpLabel
+ %varptr = OpPhi %f32ptr %ptr1 %take_ptr_1 %ptr2 %take_ptr_2
+ %result = OpLoad %f32 %varptr
+ )";
+
+ std::ostringstream spirv;
+ createVariablePointerSpirvProgram(&spirv, result_strategy,
+ false /* Add VariablePointers Capability?*/,
+ false /* Use Helper Function? */);
+ CompileSuccessfully(spirv.str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("is not a logical pointer"));
+}
+
+// With the VariablePointer Capability, OpLoad should allow loading through a
+// VaiablePointer. In this test the variable pointer is obtained from an
+// OpFunctionCall (return value from a function)
+TEST_F(ValidateIdWithMessage, OpLoadVarPtrOpFunctionCallGood) {
+ std::ostringstream spirv;
+ std::string result_strategy = R"(
+ %isneg = OpSLessThan %bool %i %zero
+ %varptr = OpFunctionCall %f32ptr %choose_input_func %isneg %ptr1 %ptr2
+ %result = OpLoad %f32 %varptr
+ )";
+
+ createVariablePointerSpirvProgram(&spirv,
+ result_strategy,
+ true /* Add VariablePointers Capability?*/,
+ true /* Use Helper Function? */);
+ CompileSuccessfully(spirv.str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
TEST_F(ValidateIdWithMessage, OpLoadResultTypeBad) {
string spirv = kGLSL450MemoryModel + R"(
%1 = OpTypeVoid
@@ -1927,6 +2084,41 @@ TEST_F(ValidateIdWithMessage, OpStoreLogicalPointerBad) {
HasSubstr("OpStore Pointer <id> '10' is not a logical pointer."));
}
+// Without the VariablePointer Capability, OpStore should may not store
+// through a variable pointer.
+TEST_F(ValidateIdWithMessage, OpStoreVarPtrBad) {
+ std::string result_strategy = R"(
+ %isneg = OpSLessThan %bool %i %zero
+ %varptr = OpSelect %f32ptr %isneg %ptr1 %ptr2
+ OpStore %varptr %float_1
+ )";
+
+ std::ostringstream spirv;
+ createVariablePointerSpirvProgram(
+ &spirv, result_strategy, false /* Add VariablePointers Capability? */,
+ false /* Use Helper Function? */);
+ CompileSuccessfully(spirv.str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("is not a logical pointer."));
+}
+
+// With the VariablePointer Capability, OpStore should allow storing through a
+// variable pointer.
+TEST_F(ValidateIdWithMessage, OpStoreVarPtrGood) {
+ std::string result_strategy = R"(
+ %isneg = OpSLessThan %bool %i %zero
+ %varptr = OpSelect %f32ptr %isneg %ptr1 %ptr2
+ OpStore %varptr %float_1
+ )";
+
+ std::ostringstream spirv;
+ createVariablePointerSpirvProgram(&spirv, result_strategy,
+ true /* Add VariablePointers Capability? */,
+ false /* Use Helper Function? */);
+ CompileSuccessfully(spirv.str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
TEST_F(ValidateIdWithMessage, OpStoreObjectGood) {
string spirv = kGLSL450MemoryModel + R"(
%1 = OpTypeVoid
@@ -3614,6 +3806,33 @@ TEST_F(ValidateIdWithMessage, OpReturnValueIsVariableInLogical) {
"which is invalid in the Logical addressing model."));
}
+// With the VariablePointer Capability, the return value of a function is
+// allowed to be a pointer.
+TEST_F(ValidateIdWithMessage, OpReturnValueVarPtrGood) {
+ std::ostringstream spirv;
+ createVariablePointerSpirvProgram(&spirv,
+ "" /* Instructions to add to "main" */,
+ true /* Add VariablePointers Capability?*/,
+ true /* Use Helper Function? */);
+ CompileSuccessfully(spirv.str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Without the VariablePointer Capability, the return value of a function is
+// *not* allowed to be a pointer.
+TEST_F(ValidateIdWithMessage, OpReturnValueVarPtrBad) {
+ std::ostringstream spirv;
+ createVariablePointerSpirvProgram(&spirv,
+ "" /* Instructions to add to "main" */,
+ false /* Add VariablePointers Capability?*/,
+ true /* Use Helper Function? */);
+ CompileSuccessfully(spirv.str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("OpReturnValue value's type <id> '7' is a pointer, "
+ "which is invalid in the Logical addressing model."));
+}
+
// TODO: enable when this bug is fixed:
// https://cvs.khronos.org/bugzilla/show_bug.cgi?id=15404
TEST_F(ValidateIdWithMessage, DISABLED_OpReturnValueIsFunction) {
diff --git a/test/val/val_type_unique_test.cpp b/test/val/val_type_unique_test.cpp
index 2330582f..724ba3d6 100644
--- a/test/val/val_type_unique_test.cpp
+++ b/test/val/val_type_unique_test.cpp
@@ -23,6 +23,7 @@
namespace {
using ::testing::HasSubstr;
+using ::testing::Not;
using std::string;
@@ -220,4 +221,21 @@ OpTypeForwardPointer %ptr2 Generic
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateTypeUnique, duplicate_void_with_extension) {
+ string str = R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Linkage
+OpCapability Pipes
+OpExtension "SPV_VALIDATOR_ignore_type_decl_unique"
+OpMemoryModel Physical32 OpenCL
+%voidt = OpTypeVoid
+%voidt2 = OpTypeVoid
+)";
+ CompileSuccessfully(str.c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ Not(HasSubstr(GetErrorString(SpvOpTypeVoid))));
+}
+
} // anonymous namespace
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index ac680f89..43b9507d 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -33,7 +33,8 @@ function(add_spvtools_tool)
target_include_directories(${ARG_TARGET} PRIVATE
${spirv-tools_SOURCE_DIR}
${spirv-tools_BINARY_DIR}
-)
+ )
+ set_property(TARGET ${ARG_TARGET} PROPERTY FOLDER "SPIRV-Tools executables")
endfunction()
if (NOT ${SPIRV_SKIP_EXECUTABLES})
@@ -41,6 +42,10 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
add_spvtools_tool(TARGET spirv-dis SRCS dis/dis.cpp LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-val SRCS val/val.cpp LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-opt SRCS opt/opt.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS})
+ add_spvtools_tool(TARGET spirv-stats
+ SRCS stats/stats.cpp
+ stats/stats_analyzer.cpp
+ LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-cfg
SRCS cfg/cfg.cpp
cfg/bin_to_dot.h
@@ -48,8 +53,10 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
LIBS ${SPIRV_TOOLS})
target_include_directories(spirv-cfg PRIVATE ${spirv-tools_SOURCE_DIR}
${SPIRV_HEADER_INCLUDE_DIR})
+ target_include_directories(spirv-stats PRIVATE ${spirv-tools_SOURCE_DIR}
+ ${SPIRV_HEADER_INCLUDE_DIR})
- set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt spirv-cfg)
+ set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt spirv-stats spirv-cfg)
install(TARGETS ${SPIRV_INSTALL_TARGETS}
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
diff --git a/tools/as/as.cpp b/tools/as/as.cpp
index 51f14fbb..eb9f9ffa 100644
--- a/tools/as/as.cpp
+++ b/tools/as/as.cpp
@@ -37,16 +37,23 @@ Options:
-o <filename> Set the output filename. Use '-' to mean stdout.
--version Display assembler version information.
- --target-env {vulkan1.0|spv1.0|spv1.1}
- Use Vulkan1.0/SPIR-V1.0/SPIR-V1.1 validation rules.
+ --preserve-numeric-ids
+ Numeric IDs in the binary will have the same values as in the
+ source. Non-numeric IDs are allocated by filling in the gaps,
+ starting with 1 and going up.
+ --target-env {vulkan1.0|spv1.0|spv1.1|spv1.2}
+ Use Vulkan1.0/SPIR-V1.0/SPIR-V1.1/SPIR-V1.2
)",
argv0, argv0);
}
+static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_2;
+
int main(int argc, char** argv) {
const char* inFile = nullptr;
const char* outFile = nullptr;
- spv_target_env target_env = SPV_ENV_UNIVERSAL_1_1;
+ uint32_t options = 0;
+ spv_target_env target_env = kDefaultEnvironment;
for (int argi = 1; argi < argc; ++argi) {
if ('-' == argv[argi][0]) {
switch (argv[argi][1]) {
@@ -76,13 +83,16 @@ int main(int argc, char** argv) {
if (0 == strcmp(argv[argi], "--version")) {
printf("%s\n", spvSoftwareVersionDetailsString());
printf("Target: %s\n",
- spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_1));
+ spvTargetEnvDescription(kDefaultEnvironment));
return 0;
}
if (0 == strcmp(argv[argi], "--help")) {
print_usage(argv[0]);
return 0;
}
+ if (0 == strcmp(argv[argi], "--preserve-numeric-ids")) {
+ options |= SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS;
+ }
if (0 == strcmp(argv[argi], "--target-env")) {
if (argi + 1 < argc) {
const auto env_str = argv[++argi];
@@ -121,8 +131,8 @@ int main(int argc, char** argv) {
spv_binary binary;
spv_diagnostic diagnostic = nullptr;
spv_context context = spvContextCreate(target_env);
- spv_result_t error = spvTextToBinary(context, contents.data(),
- contents.size(), &binary, &diagnostic);
+ spv_result_t error = spvTextToBinaryWithOptions(
+ context, contents.data(), contents.size(), options, &binary, &diagnostic);
spvContextDestroy(context);
if (error) {
spvDiagnosticPrint(diagnostic);
diff --git a/tools/cfg/cfg.cpp b/tools/cfg/cfg.cpp
index b2fc86cb..f609ddef 100644
--- a/tools/cfg/cfg.cpp
+++ b/tools/cfg/cfg.cpp
@@ -45,6 +45,8 @@ Options:
argv0, argv0);
}
+static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_2;
+
int main(int argc, char** argv) {
const char* inFile = nullptr;
const char* outFile = nullptr; // Stays nullptr if printing to stdout.
@@ -71,7 +73,7 @@ int main(int argc, char** argv) {
} else if (0 == strcmp(argv[argi], "--version")) {
printf("%s EXPERIMENTAL\n", spvSoftwareVersionDetailsString());
printf("Target: %s\n",
- spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_1));
+ spvTargetEnvDescription(kDefaultEnvironment));
return 0;
} else {
print_usage(argv[0]);
@@ -104,7 +106,7 @@ int main(int argc, char** argv) {
// Read the input binary.
std::vector<uint32_t> contents;
if (!ReadFile<uint32_t>(inFile, "rb", &contents)) return 1;
- spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+ spv_context context = spvContextCreate(kDefaultEnvironment);
spv_diagnostic diagnostic = nullptr;
std::stringstream ss;
diff --git a/tools/dis/dis.cpp b/tools/dis/dis.cpp
index 0f335840..226f733a 100644
--- a/tools/dis/dis.cpp
+++ b/tools/dis/dis.cpp
@@ -52,6 +52,8 @@ Options:
argv0, argv0);
}
+static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_2;
+
int main(int argc, char** argv) {
const char* inFile = nullptr;
const char* outFile = nullptr;
@@ -97,7 +99,7 @@ int main(int argc, char** argv) {
} else if (0 == strcmp(argv[argi], "--version")) {
printf("%s\n", spvSoftwareVersionDetailsString());
printf("Target: %s\n",
- spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_1));
+ spvTargetEnvDescription(kDefaultEnvironment));
return 0;
} else {
print_usage(argv[0]);
@@ -160,7 +162,7 @@ int main(int argc, char** argv) {
spv_text text = nullptr;
spv_text* textOrNull = print_to_stdout ? nullptr : &text;
spv_diagnostic diagnostic = nullptr;
- spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+ spv_context context = spvContextCreate(kDefaultEnvironment);
spv_result_t error =
spvBinaryToText(context, contents.data(), contents.size(), options,
textOrNull, &diagnostic);
diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp
index d9dea339..21dd6764 100644
--- a/tools/opt/opt.cpp
+++ b/tools/opt/opt.cpp
@@ -13,16 +13,16 @@
// limitations under the License.
#include <cstring>
+#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <vector>
+#include "opt/set_spec_constant_default_value_pass.h"
+#include "spirv-tools/optimizer.hpp"
+
#include "message.h"
-#include "source/opt/build_module.h"
-#include "source/opt/ir_loader.h"
-#include "source/opt/pass_manager.h"
-#include "source/opt/passes.h"
#include "tools/io.h"
using namespace spvtools;
@@ -61,8 +61,16 @@ Options:
e.g.: --set-spec-const-default-value "1:100 2:400"
--unify-const
Remove the duplicated constants.
- --inline-entry-points-all
- Exhaustively inline all function calls in entry points
+ --inline-entry-points-exhaustive
+ Exhaustively inline all function calls in entry point functions.
+ Currently does not inline calls to functions with multiple
+ returns.
+ --flatten-decorations
+ Replace decoration groups with repeated OpDecorate and
+ OpMemberDecorate instructions.
+ --compact-ids
+ Remap result ids to a compact range starting from %%1 and without
+ any gaps.
-h, --help Print this help.
--version Display optimizer version information.
)",
@@ -73,15 +81,15 @@ int main(int argc, char** argv) {
const char* in_file = nullptr;
const char* out_file = nullptr;
- spv_target_env target_env = SPV_ENV_UNIVERSAL_1_1;
+ spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
- opt::PassManager pass_manager;
- pass_manager.SetMessageConsumer(
- [](spv_message_level_t level, const char* source,
- const spv_position_t& position, const char* message) {
- std::cerr << StringifyMessage(level, source, position, message)
- << std::endl;
- });
+ spvtools::Optimizer optimizer(target_env);
+ optimizer.SetMessageConsumer([](spv_message_level_t level, const char* source,
+ const spv_position_t& position,
+ const char* message) {
+ std::cerr << StringifyMessage(level, source, position, message)
+ << std::endl;
+ });
for (int argi = 1; argi < argc; ++argi) {
const char* cur_arg = argv[argi];
@@ -100,7 +108,7 @@ int main(int argc, char** argv) {
return 1;
}
} else if (0 == strcmp(cur_arg, "--strip-debug")) {
- pass_manager.AddPass<opt::StripDebugInfoPass>();
+ optimizer.RegisterPass(CreateStripDebugInfoPass());
} else if (0 == strcmp(cur_arg, "--set-spec-const-default-value")) {
if (++argi < argc) {
auto spec_ids_vals =
@@ -113,8 +121,8 @@ int main(int argc, char** argv) {
argv[argi]);
return 1;
}
- pass_manager.AddPass<opt::SetSpecConstantDefaultValuePass>(
- std::move(*spec_ids_vals));
+ optimizer.RegisterPass(
+ CreateSetSpecConstantDefaultValuePass(std::move(*spec_ids_vals)));
} else {
fprintf(
stderr,
@@ -122,15 +130,23 @@ int main(int argc, char** argv) {
return 1;
}
} else if (0 == strcmp(cur_arg, "--freeze-spec-const")) {
- pass_manager.AddPass<opt::FreezeSpecConstantValuePass>();
- } else if (0 == strcmp(cur_arg, "--inline-entry-points-all")) {
- pass_manager.AddPass<opt::InlinePass>();
+ optimizer.RegisterPass(CreateFreezeSpecConstantValuePass());
+ } else if (0 == strcmp(cur_arg, "--inline-entry-points-exhaustive")) {
+ optimizer.RegisterPass(CreateInlinePass());
+ } else if (0 == strcmp(cur_arg, "--convert-local-access-chains")) {
+ optimizer.RegisterPass(CreateLocalAccessChainConvertPass());
+ } else if (0 == strcmp(cur_arg, "--eliminate-local-single-block")) {
+ optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass());
} else if (0 == strcmp(cur_arg, "--eliminate-dead-const")) {
- pass_manager.AddPass<opt::EliminateDeadConstantPass>();
+ optimizer.RegisterPass(CreateEliminateDeadConstantPass());
} else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) {
- pass_manager.AddPass<opt::FoldSpecConstantOpAndCompositePass>();
+ optimizer.RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
} else if (0 == strcmp(cur_arg, "--unify-const")) {
- pass_manager.AddPass<opt::UnifyConstantPass>();
+ optimizer.RegisterPass(CreateUnifyConstantPass());
+ } else if (0 == strcmp(cur_arg, "--flatten-decorations")) {
+ optimizer.RegisterPass(CreateFlattenDecorationPass());
+ } else if (0 == strcmp(cur_arg, "--compact-ids")) {
+ optimizer.RegisterPass(CreateCompactIdsPass());
} else if ('\0' == cur_arg[1]) {
// Setting a filename of "-" to indicate stdin.
if (!in_file) {
@@ -158,14 +174,14 @@ int main(int argc, char** argv) {
return 1;
}
- std::vector<uint32_t> source;
- if (!ReadFile<uint32_t>(in_file, "rb", &source)) return 1;
+ std::vector<uint32_t> binary;
+ if (!ReadFile<uint32_t>(in_file, "rb", &binary)) return 1;
// Let's do validation first.
spv_context context = spvContextCreate(target_env);
spv_diagnostic diagnostic = nullptr;
- spv_const_binary_t binary = {source.data(), source.size()};
- spv_result_t error = spvValidate(context, &binary, &diagnostic);
+ spv_const_binary_t binary_struct = {binary.data(), binary.size()};
+ spv_result_t error = spvValidate(context, &binary_struct, &diagnostic);
if (error) {
spvDiagnosticPrint(diagnostic);
spvDiagnosticDestroy(diagnostic);
@@ -175,16 +191,13 @@ int main(int argc, char** argv) {
spvDiagnosticDestroy(diagnostic);
spvContextDestroy(context);
- std::unique_ptr<ir::Module> module = BuildModule(
- target_env, pass_manager.consumer(), source.data(), source.size());
- pass_manager.Run(module.get());
-
- std::vector<uint32_t> target;
- module->ToBinary(&target, /* skip_nop = */ true);
+ // By using the same vector as input and output, we save time in the case
+ // that there was no change.
+ bool ok = optimizer.Run(binary.data(), binary.size(), &binary);
- if (!WriteFile<uint32_t>(out_file, "wb", target.data(), target.size())) {
+ if (!WriteFile<uint32_t>(out_file, "wb", binary.data(), binary.size())) {
return 1;
}
- return 0;
+ return ok ? 0 : 1;
}
diff --git a/tools/stats/stats.cpp b/tools/stats/stats.cpp
new file mode 100644
index 00000000..51d61834
--- /dev/null
+++ b/tools/stats/stats.cpp
@@ -0,0 +1,158 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cassert>
+#include <cstring>
+#include <iostream>
+#include <unordered_map>
+
+#include "source/spirv_stats.h"
+#include "source/table.h"
+#include "spirv-tools/libspirv.h"
+#include "stats_analyzer.h"
+#include "tools/io.h"
+
+using libspirv::SpirvStats;
+
+namespace {
+
+struct ScopedContext {
+ ScopedContext(spv_target_env env) : context(spvContextCreate(env)) {}
+ ~ScopedContext() { spvContextDestroy(context); }
+ spv_context context;
+};
+
+void PrintUsage(char* argv0) {
+ printf(
+ R"(%s - Collect statistics from one or more SPIR-V binary file(s).
+
+USAGE: %s [options] [<filepaths>]
+
+TIP: In order to collect statistics from all .spv files under current dir use
+find . -name "*.spv" -print0 | xargs -0 -s 2000000 %s
+
+Options:
+ -h, --help Print this help.
+ -v, --verbose Print additional info to stderr.
+)",
+ argv0, argv0, argv0);
+}
+
+void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
+ const spv_position_t& position,
+ const char* message) {
+ switch (level) {
+ case SPV_MSG_FATAL:
+ case SPV_MSG_INTERNAL_ERROR:
+ case SPV_MSG_ERROR:
+ std::cerr << "error: " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_WARNING:
+ std::cout << "warning: " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_INFO:
+ std::cout << "info: " << position.index << ": " << message << std::endl;
+ break;
+ default:
+ break;
+ }
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+ bool continue_processing = true;
+ int return_code = 0;
+
+ bool verbose = false;
+
+ std::vector<const char*> paths;
+
+ for (int argi = 1; continue_processing && argi < argc; ++argi) {
+ const char* cur_arg = argv[argi];
+ if ('-' == cur_arg[0]) {
+ if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) {
+ PrintUsage(argv[0]);
+ continue_processing = false;
+ return_code = 0;
+ } else if (0 == strcmp(cur_arg, "--verbose") || 0 == strcmp(cur_arg, "-v")) {
+ verbose = true;
+ } else {
+ PrintUsage(argv[0]);
+ continue_processing = false;
+ return_code = 1;
+ }
+ } else {
+ paths.push_back(cur_arg);
+ }
+ }
+
+ // Exit if command line parsing was not successful.
+ if (!continue_processing) {
+ return return_code;
+ }
+
+ std::cerr << "Processing " << paths.size() << " files..." << std::endl;
+
+ ScopedContext ctx(SPV_ENV_UNIVERSAL_1_1);
+ SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler);
+
+ libspirv::SpirvStats stats;
+ stats.opcode_markov_hist.resize(1);
+
+ for (size_t index = 0; index < paths.size(); ++index) {
+ const size_t kMilestonePeriod = 1000;
+ if (verbose) {
+ if (index % kMilestonePeriod == kMilestonePeriod - 1)
+ std::cerr << "Processed " << index + 1 << " files..." << std::endl;
+ }
+
+ const char* path = paths[index];
+ std::vector<uint32_t> contents;
+ if (!ReadFile<uint32_t>(path, "rb", &contents)) return 1;
+
+ if (SPV_SUCCESS != libspirv::AggregateStats(
+ *ctx.context, contents.data(), contents.size(), nullptr, &stats)) {
+ std::cerr << "error: Failed to aggregate stats for " << path << std::endl;
+ return 1;
+ }
+ }
+
+ StatsAnalyzer analyzer(stats);
+
+ std::ostream& out = std::cout;
+
+ out << std::endl;
+ analyzer.WriteVersion(out);
+ analyzer.WriteGenerator(out);
+
+ out << std::endl;
+ analyzer.WriteCapability(out);
+
+ out << std::endl;
+ analyzer.WriteExtension(out);
+
+ out << std::endl;
+ analyzer.WriteOpcode(out);
+
+ out << std::endl;
+ analyzer.WriteOpcodeMarkov(out);
+
+ out << std::endl;
+ analyzer.WriteConstantLiterals(out);
+
+ return 0;
+}
diff --git a/tools/stats/stats_analyzer.cpp b/tools/stats/stats_analyzer.cpp
new file mode 100644
index 00000000..9e248a42
--- /dev/null
+++ b/tools/stats/stats_analyzer.cpp
@@ -0,0 +1,241 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "stats_analyzer.h"
+
+#include <algorithm>
+#include <iostream>
+#include <sstream>
+#include <vector>
+
+#include "source/enum_string_mapping.h"
+#include "source/opcode.h"
+#include "source/spirv_constant.h"
+#include "spirv/1.1/spirv.h"
+
+using libspirv::SpirvStats;
+
+namespace {
+
+std::string GetVersionString(uint32_t word) {
+ std::stringstream ss;
+ ss << "Version " << SPV_SPIRV_VERSION_MAJOR_PART(word)
+ << "." << SPV_SPIRV_VERSION_MINOR_PART(word);
+ return ss.str();
+}
+
+std::string GetGeneratorString(uint32_t word) {
+ return spvGeneratorStr(SPV_GENERATOR_TOOL_PART(word));
+}
+
+std::string GetOpcodeString(uint32_t word) {
+ return spvOpcodeString(static_cast<SpvOp>(word));
+}
+
+std::string GetCapabilityString(uint32_t word) {
+ return libspirv::CapabilityToString(static_cast<SpvCapability>(word));
+}
+
+template <class T>
+std::string KeyIsLabel(T key) {
+ std::stringstream ss;
+ ss << key;
+ return ss.str();
+}
+
+template <class Key>
+std::unordered_map<Key, double> GetRecall(
+ const std::unordered_map<Key, uint32_t>& hist, uint64_t total) {
+ std::unordered_map<Key, double> freq;
+ for (const auto& pair : hist) {
+ const double frequency =
+ static_cast<double>(pair.second) / static_cast<double>(total);
+ freq.emplace(pair.first, frequency);
+ }
+ return freq;
+}
+
+template <class Key>
+std::unordered_map<Key, double> GetPrevalence(
+ const std::unordered_map<Key, uint32_t>& hist) {
+ uint64_t total = 0;
+ for (const auto& pair : hist) {
+ total += pair.second;
+ }
+
+ return GetRecall(hist, total);
+}
+
+// Writes |freq| to |out| sorted by frequency in the following format:
+// LABEL3 70%
+// LABEL1 20%
+// LABEL2 10%
+// |label_from_key| is used to convert |Key| to label.
+template <class Key>
+void WriteFreq(std::ostream& out, const std::unordered_map<Key, double>& freq,
+ std::string (*label_from_key)(Key), double threshold = 0.001) {
+ std::vector<std::pair<Key, double>> sorted_freq(freq.begin(), freq.end());
+ std::sort(sorted_freq.begin(), sorted_freq.end(),
+ [](const std::pair<Key, double>& left,
+ const std::pair<Key, double>& right) {
+ return left.second > right.second;
+ });
+
+ for (const auto& pair : sorted_freq) {
+ if (pair.second < threshold)
+ break;
+ out << label_from_key(pair.first) << " " << pair.second * 100.0
+ << "%" << std::endl;
+ }
+}
+
+// Writes |hist| to |out| sorted by count in the following format:
+// LABEL3 100
+// LABEL1 50
+// LABEL2 10
+// |label_from_key| is used to convert |Key| to label.
+template <class Key>
+void WriteHist(std::ostream& out, const std::unordered_map<Key, uint32_t>& hist,
+ std::string (*label_from_key)(Key)) {
+ std::vector<std::pair<Key, uint32_t>> sorted_hist(hist.begin(), hist.end());
+ std::sort(sorted_hist.begin(), sorted_hist.end(),
+ [](const std::pair<Key, uint32_t>& left,
+ const std::pair<Key, uint32_t>& right) {
+ return left.second > right.second;
+ });
+
+ for (const auto& pair : sorted_hist) {
+ out << label_from_key(pair.first) << " " << pair.second << std::endl;
+ }
+}
+
+} // namespace
+
+StatsAnalyzer::StatsAnalyzer(const SpirvStats& stats) : stats_(stats) {
+ num_modules_ = 0;
+ for (const auto& pair : stats_.version_hist) {
+ num_modules_ += pair.second;
+ }
+
+ version_freq_ = GetRecall(stats_.version_hist, num_modules_);
+ generator_freq_ = GetRecall(stats_.generator_hist, num_modules_);
+ capability_freq_ = GetRecall(stats_.capability_hist, num_modules_);
+ extension_freq_ = GetRecall(stats_.extension_hist, num_modules_);
+ opcode_freq_ = GetPrevalence(stats_.opcode_hist);
+}
+
+void StatsAnalyzer::WriteVersion(std::ostream& out) {
+ WriteFreq(out, version_freq_, GetVersionString);
+}
+
+void StatsAnalyzer::WriteGenerator(std::ostream& out) {
+ WriteFreq(out, generator_freq_, GetGeneratorString);
+}
+
+void StatsAnalyzer::WriteCapability(std::ostream& out) {
+ WriteFreq(out, capability_freq_, GetCapabilityString);
+}
+
+void StatsAnalyzer::WriteExtension(std::ostream& out) {
+ WriteFreq(out, extension_freq_, KeyIsLabel);
+}
+
+void StatsAnalyzer::WriteOpcode(std::ostream& out) {
+ out << "Total unique opcodes used: " << opcode_freq_.size() << std::endl;
+ WriteFreq(out, opcode_freq_, GetOpcodeString);
+}
+
+void StatsAnalyzer::WriteConstantLiterals(std::ostream& out) {
+ out << "Constant literals" << std::endl;
+
+ out << "Float 32" << std::endl;
+ WriteFreq(out, GetPrevalence(stats_.f32_constant_hist), KeyIsLabel);
+
+ out << std::endl << "Float 64" << std::endl;
+ WriteFreq(out, GetPrevalence(stats_.f64_constant_hist), KeyIsLabel);
+
+ out << std::endl << "Unsigned int 16" << std::endl;
+ WriteFreq(out, GetPrevalence(stats_.u16_constant_hist), KeyIsLabel);
+
+ out << std::endl << "Signed int 16" << std::endl;
+ WriteFreq(out, GetPrevalence(stats_.s16_constant_hist), KeyIsLabel);
+
+ out << std::endl << "Unsigned int 32" << std::endl;
+ WriteFreq(out, GetPrevalence(stats_.u32_constant_hist), KeyIsLabel);
+
+ out << std::endl << "Signed int 32" << std::endl;
+ WriteFreq(out, GetPrevalence(stats_.s32_constant_hist), KeyIsLabel);
+
+ out << std::endl << "Unsigned int 64" << std::endl;
+ WriteFreq(out, GetPrevalence(stats_.u64_constant_hist), KeyIsLabel);
+
+ out << std::endl << "Signed int 64" << std::endl;
+ WriteFreq(out, GetPrevalence(stats_.s64_constant_hist), KeyIsLabel);
+}
+
+void StatsAnalyzer::WriteOpcodeMarkov(std::ostream& out) {
+ if (stats_.opcode_markov_hist.empty())
+ return;
+
+ const std::unordered_map<uint32_t, std::unordered_map<uint32_t, uint32_t>>&
+ cue_to_hist = stats_.opcode_markov_hist[0];
+
+ // Sort by prevalence of the opcodes in opcode_freq_ (descending).
+ std::vector<std::pair<uint32_t, std::unordered_map<uint32_t, uint32_t>>>
+ sorted_cue_to_hist(cue_to_hist.begin(), cue_to_hist.end());
+ std::sort(sorted_cue_to_hist.begin(), sorted_cue_to_hist.end(),
+ [this](
+ const std::pair<uint32_t,
+ std::unordered_map<uint32_t, uint32_t>>& left,
+ const std::pair<uint32_t,
+ std::unordered_map<uint32_t, uint32_t>>& right) {
+ const double lf = opcode_freq_[left.first];
+ const double rf = opcode_freq_[right.first];
+ if (lf == rf)
+ return right.first > left.first;
+ return lf > rf;
+ });
+
+ for (const auto& kv : sorted_cue_to_hist) {
+ const uint32_t cue = kv.first;
+ const double kFrequentEnoughToAnalyze = 0.0001;
+ if (opcode_freq_[cue] < kFrequentEnoughToAnalyze) continue;
+
+ const std::unordered_map<uint32_t, uint32_t>& hist = kv.second;
+
+ uint32_t total = 0;
+ for (const auto& pair : hist) {
+ total += pair.second;
+ }
+
+ std::vector<std::pair<uint32_t, uint32_t>>
+ sorted_hist(hist.begin(), hist.end());
+ std::sort(sorted_hist.begin(), sorted_hist.end(),
+ [](const std::pair<uint32_t, uint32_t>& left,
+ const std::pair<uint32_t, uint32_t>& right) {
+ if (left.second == right.second)
+ return right.first > left.first;
+ return left.second > right.second;
+ });
+
+ for (const auto& pair : sorted_hist) {
+ const double prior = opcode_freq_[pair.first];
+ const double posterior =
+ static_cast<double>(pair.second) / static_cast<double>(total);
+ out << GetOpcodeString(cue) << " -> " << GetOpcodeString(pair.first)
+ << " " << posterior * 100 << "% (base rate " << prior * 100
+ << "%, pair occurrences " << pair.second << ")" << std::endl;
+ }
+ }
+}
diff --git a/tools/stats/stats_analyzer.h b/tools/stats/stats_analyzer.h
new file mode 100644
index 00000000..c1ff1871
--- /dev/null
+++ b/tools/stats/stats_analyzer.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_TOOLS_STATS_STATS_ANALYZER_H_
+#define LIBSPIRV_TOOLS_STATS_STATS_ANALYZER_H_
+
+#include <unordered_map>
+
+#include "source/spirv_stats.h"
+
+class StatsAnalyzer {
+ public:
+ explicit StatsAnalyzer(const libspirv::SpirvStats& stats);
+
+ // Writes respective histograms to |out|.
+ void WriteVersion(std::ostream& out);
+ void WriteGenerator(std::ostream& out);
+ void WriteCapability(std::ostream& out);
+ void WriteExtension(std::ostream& out);
+ void WriteOpcode(std::ostream& out);
+ void WriteConstantLiterals(std::ostream& out);
+
+ // Writes first order Markov analysis to |out|.
+ // stats_.opcode_markov_hist needs to contain raw data for at least one
+ // level.
+ void WriteOpcodeMarkov(std::ostream& out);
+
+ private:
+ const libspirv::SpirvStats& stats_;
+
+ uint32_t num_modules_;
+
+ std::unordered_map<uint32_t, double> version_freq_;
+ std::unordered_map<uint32_t, double> generator_freq_;
+ std::unordered_map<uint32_t, double> capability_freq_;
+ std::unordered_map<std::string, double> extension_freq_;
+ std::unordered_map<uint32_t, double> opcode_freq_;
+};
+
+
+#endif // LIBSPIRV_TOOLS_STATS_STATS_ANALYZER_H_
diff --git a/tools/val/val.cpp b/tools/val/val.cpp
index 270c42b1..2c065196 100644
--- a/tools/val/val.cpp
+++ b/tools/val/val.cpp
@@ -45,15 +45,15 @@ Options:
--max-control-flow-nesting-depth <maximum Control Flow nesting depth allowed>
--max-access-chain-indexes <maximum number of indexes allowed to use for Access Chain instructions>
--version Display validator version information.
- --target-env {vulkan1.0|spv1.0|spv1.1}
- Use Vulkan1.0/SPIR-V1.0/SPIR-V1.1 validation rules.
+ --target-env {vulkan1.0|spv1.0|spv1.1|spv1.2}
+ Use Vulkan1.0/SPIR-V1.0/SPIR-V1.1/SPIR-V1.2 validation rules.
)",
argv0, argv0);
}
int main(int argc, char** argv) {
const char* inFile = nullptr;
- spv_target_env target_env = SPV_ENV_UNIVERSAL_1_1;
+ spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
spvtools::ValidatorOptions options;
bool continue_processing = true;
int return_code = 0;
@@ -85,9 +85,11 @@ int main(int argc, char** argv) {
}
} else if (0 == strcmp(cur_arg, "--version")) {
printf("%s\n", spvSoftwareVersionDetailsString());
- printf("Targets:\n %s\n %s\n",
+ // TODO(dneto): Add OpenCL 2.2 at least.
+ printf("Targets:\n %s\n %s\n %s\n",
spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_1),
- spvTargetEnvDescription(SPV_ENV_VULKAN_1_0));
+ spvTargetEnvDescription(SPV_ENV_VULKAN_1_0),
+ spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_2));
continue_processing = false;
return_code = 0;
} else if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) {
diff --git a/utils/generate_grammar_tables.py b/utils/generate_grammar_tables.py
index fc7e56ce..c7741f33 100755
--- a/utils/generate_grammar_tables.py
+++ b/utils/generate_grammar_tables.py
@@ -26,7 +26,14 @@ import re
PYGEN_VARIABLE_PREFIX = 'pygen_variable'
# Extensions to recognize, but which don't come from the SPIRV-V core grammar.
-NONSTANDARD_EXTENSIONS = ['SPV_AMD_gcn_shader',]
+NONSTANDARD_EXTENSIONS = [
+ 'SPV_AMD_gcn_shader',
+ # Validator would ignore type declaration unique check. Should only be used
+ # for legacy autogenerated test files containing multiple instances of the
+ # same type declaration, if fixing the test by other methods is too
+ # difficult. Shouldn't be used for any other reasons.
+ 'SPV_VALIDATOR_ignore_type_decl_unique',
+]
def make_path_to_file(f):
"""Makes all ancestor directories to the given file, if they
@@ -364,11 +371,13 @@ def get_extension_list(operands):
return sorted(set(extensions))
-def get_capability_list(operands):
- """Returns capabilities as a list of strings in the order of appearance."""
+def get_capabilities(operands):
+ """Returns capabilities as a list of JSON objects, in order of
+ appearance.
+ """
enumerants = sum([item.get('enumerants', []) for item in operands
if item.get('kind') in ['Capability']], [])
- return [item.get('enumerant') for item in enumerants]
+ return enumerants
def generate_extension_enum(operands):
@@ -399,7 +408,8 @@ def generate_string_to_extension_table(operands):
def generate_capability_to_string_table(operands):
"""Returns capability to string mapping table."""
- capabilities = get_capability_list(operands)
+ capabilities = [item.get('enumerant')
+ for item in get_capabilities(operands)]
entry_template = ' {{SpvCapability{capability},\n "{capability}"}}'
table_entries = [entry_template.format(capability=capability)
for capability in capabilities]
@@ -434,14 +444,19 @@ def generate_string_to_extension_mapping(operands):
def generate_capability_to_string_mapping(operands):
- """Returns mapping function from capabilities to corresponding strings."""
- capabilities = get_capability_list(operands)
+ """Returns mapping function from capabilities to corresponding strings.
+ We take care to avoid emitting duplicate values.
+ """
function = 'std::string CapabilityToString(SpvCapability capability) {\n'
function += ' switch (capability) {\n'
template = ' case SpvCapability{capability}:\n' \
' return "{capability}";\n'
- function += ''.join([template.format(capability=capability)
- for capability in capabilities])
+ emitted = set() # The values of capabilities we already have emitted
+ for capability in get_capabilities(operands):
+ value = capability.get('value')
+ if value not in emitted:
+ emitted.add(value)
+ function += template.format(capability=capability.get('enumerant'))
function += ' case SpvCapabilityMax:\n' \
' assert(0 && "Attempting to convert SpvCapabilityMax to string");\n' \
' return "";\n'